mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
New.
This commit is contained in:
38
test/unit/test_distogram_parser.py
Normal file
38
test/unit/test_distogram_parser.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
import alphapulldown.utils.distogram_parser as distogram_parser_module
|
||||
|
||||
|
||||
def test_get_contacts_returns_empty_list_when_no_pickles_exist(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(distogram_parser_module, "datadir", str(tmp_path), raising=False)
|
||||
|
||||
parser = distogram_parser_module.distogram_parser()
|
||||
|
||||
assert parser.get_contacts("ignored") == []
|
||||
|
||||
|
||||
def test_get_contacts_extracts_top_ranked_inter_chain_contact(monkeypatch, tmp_path):
|
||||
logits = np.full((4, 4, 3), -10.0, dtype=np.float32)
|
||||
logits[0, 2, 0] = 10.0
|
||||
logits[2, 0, 0] = 10.0
|
||||
payload = {
|
||||
"ranking_confidence": 0.9,
|
||||
"seqs": ["AA", "BB"],
|
||||
"distogram": {
|
||||
"bin_edges": np.array([4.0, 8.0, 12.0], dtype=np.float32),
|
||||
"logits": logits,
|
||||
},
|
||||
}
|
||||
with open(tmp_path / "result_model.pkl", "wb") as handle:
|
||||
pickle.dump(payload, handle)
|
||||
monkeypatch.setattr(distogram_parser_module, "datadir", str(tmp_path), raising=False)
|
||||
|
||||
parser = distogram_parser_module.distogram_parser()
|
||||
contacts = parser.get_contacts("ignored", distance=9, pbtycutoff=0.5, cross_only=True)
|
||||
|
||||
assert len(contacts) == 1
|
||||
assert contacts[0][0] == (1, "A")
|
||||
assert contacts[0][1] == (1, "B")
|
||||
assert contacts[0][2] > 0.99
|
||||
118
test/unit/test_file_handling.py
Normal file
118
test/unit/test_file_handling.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from alphapulldown.utils import file_handling
|
||||
|
||||
|
||||
def test_temp_fasta_file_writes_contents_and_cleans_up():
|
||||
with file_handling.temp_fasta_file(">seq\nACDE\n") as fasta_path:
|
||||
path = Path(fasta_path)
|
||||
assert path.is_file()
|
||||
assert path.read_text(encoding="utf-8") == ">seq\nACDE\n"
|
||||
|
||||
assert not path.exists()
|
||||
|
||||
|
||||
def test_convert_fasta_description_to_protein_name_sanitizes_symbols():
|
||||
protein_name = file_handling.convert_fasta_description_to_protein_name(
|
||||
">sp|P12345|Protein A:chain;1?"
|
||||
)
|
||||
|
||||
assert protein_name == "sp_P12345_Protein_A_chain_1_"
|
||||
|
||||
|
||||
def test_parse_fasta_parses_multiple_sequences_and_skips_blank_lines():
|
||||
sequences, descriptions = file_handling.parse_fasta(
|
||||
">protein one\nACD\n\nEF\n>protein|two\nGHI\n"
|
||||
)
|
||||
|
||||
assert sequences == ["ACDEF", "GHI"]
|
||||
assert descriptions == ["protein_one", "protein_two"]
|
||||
|
||||
|
||||
def test_iter_seqs_yields_sequences_from_multiple_files(tmp_path):
|
||||
fasta_a = tmp_path / "a.fasta"
|
||||
fasta_b = tmp_path / "b.fasta"
|
||||
fasta_a.write_text(">first\nAAA\n", encoding="utf-8")
|
||||
fasta_b.write_text(">second\nBBB\n", encoding="utf-8")
|
||||
|
||||
records = list(file_handling.iter_seqs([str(fasta_a), str(fasta_b)]))
|
||||
|
||||
assert records == [("AAA", "first"), ("BBB", "second")]
|
||||
|
||||
|
||||
def test_make_dir_monomer_dictionary_maps_files_to_their_source_dirs(tmp_path):
|
||||
dir_a = tmp_path / "a"
|
||||
dir_b = tmp_path / "b"
|
||||
dir_a.mkdir()
|
||||
dir_b.mkdir()
|
||||
(dir_a / "proteinA.pkl").write_text("", encoding="utf-8")
|
||||
(dir_b / "proteinB.pkl.xz").write_text("", encoding="utf-8")
|
||||
|
||||
result = file_handling.make_dir_monomer_dictionary([str(dir_a), str(dir_b)])
|
||||
|
||||
assert result == {
|
||||
"proteinA.pkl": str(dir_a),
|
||||
"proteinB.pkl.xz": str(dir_b),
|
||||
}
|
||||
|
||||
|
||||
def test_parse_csv_file_raises_for_missing_fasta():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
file_handling.parse_csv_file("missing.csv", ["missing.fasta"], "/templates")
|
||||
|
||||
|
||||
def test_parse_csv_file_returns_unique_records_without_clustering(tmp_path):
|
||||
fasta = tmp_path / "proteins.fasta"
|
||||
fasta.write_text(">proteinA\nACDE\n", encoding="utf-8")
|
||||
csv_path = tmp_path / "description.csv"
|
||||
csv_path.write_text(
|
||||
"proteinA,template1.cif,A\nproteinA,template2.cif,B\ninvalid,row\nmissing,template3.cif,C\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = file_handling.parse_csv_file(
|
||||
str(csv_path), [str(fasta)], str(tmp_path / "templates"), cluster=False
|
||||
)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"protein": "proteinA.template1.cif.A",
|
||||
"sequence": "ACDE",
|
||||
"templates": [str(tmp_path / "templates" / "template1.cif")],
|
||||
"chains": ["A"],
|
||||
},
|
||||
{
|
||||
"protein": "proteinA.template2.cif.B",
|
||||
"sequence": "ACDE",
|
||||
"templates": [str(tmp_path / "templates" / "template2.cif")],
|
||||
"chains": ["B"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_parse_csv_file_clusters_multiple_templates_per_protein(tmp_path):
|
||||
fasta = tmp_path / "proteins.fasta"
|
||||
fasta.write_text(">proteinA\nACDE\n", encoding="utf-8")
|
||||
csv_path = tmp_path / "description.csv"
|
||||
csv_path.write_text(
|
||||
"proteinA,template1.cif,A\nproteinA,template2.cif,B\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = file_handling.parse_csv_file(
|
||||
str(csv_path), [str(fasta)], str(tmp_path / "templates"), cluster=True
|
||||
)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"protein": "proteinA",
|
||||
"sequence": "ACDE",
|
||||
"templates": [
|
||||
str(tmp_path / "templates" / "template1.cif"),
|
||||
str(tmp_path / "templates" / "template2.cif"),
|
||||
],
|
||||
"chains": ["A", "B"],
|
||||
}
|
||||
]
|
||||
17
test/unit/test_plotting.py
Normal file
17
test/unit/test_plotting.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import numpy as np
|
||||
|
||||
from alphapulldown.utils.plotting import plot_pae_from_matrix
|
||||
|
||||
|
||||
def test_plot_pae_from_matrix_writes_png_file(tmp_path):
|
||||
output = tmp_path / "pae.png"
|
||||
|
||||
plot_pae_from_matrix(
|
||||
seqs=["AA", "BBB"],
|
||||
pae_matrix=np.arange(25, dtype=float).reshape(5, 5),
|
||||
figure_name=str(output),
|
||||
ranking=2,
|
||||
)
|
||||
|
||||
assert output.is_file()
|
||||
assert output.stat().st_size > 0
|
||||
94
test/unit/test_post_modelling.py
Normal file
94
test/unit/test_post_modelling.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import gzip
|
||||
import json
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from alphapulldown.utils import post_modelling
|
||||
|
||||
|
||||
def test_compress_file_gzips_and_removes_original(tmp_path):
|
||||
file_path = tmp_path / "result.pkl"
|
||||
file_path.write_bytes(b"payload")
|
||||
|
||||
gz_path = post_modelling.compress_file(str(file_path))
|
||||
|
||||
assert gz_path == str(file_path) + ".gz"
|
||||
assert not file_path.exists()
|
||||
with gzip.open(gz_path, "rb") as handle:
|
||||
assert handle.read() == b"payload"
|
||||
|
||||
|
||||
def test_compress_result_pickles_only_compresses_pickle_files(tmp_path):
|
||||
pickle_path = tmp_path / "result.pkl"
|
||||
text_path = tmp_path / "notes.txt"
|
||||
pickle_path.write_bytes(b"payload")
|
||||
text_path.write_text("keep", encoding="utf-8")
|
||||
|
||||
post_modelling.compress_result_pickles(str(tmp_path))
|
||||
|
||||
assert not pickle_path.exists()
|
||||
assert (tmp_path / "result.pkl.gz").is_file()
|
||||
assert text_path.is_file()
|
||||
|
||||
|
||||
def test_remove_keys_from_pickle_updates_pickle_in_place(tmp_path):
|
||||
file_path = tmp_path / "result.pkl"
|
||||
with open(file_path, "wb") as handle:
|
||||
pickle.dump({"keep": 1, "distogram": 2, "masked_msa": 3}, handle)
|
||||
|
||||
post_modelling.remove_keys_from_pickle(str(file_path), ["distogram", "missing"])
|
||||
|
||||
with open(file_path, "rb") as handle:
|
||||
payload = pickle.load(handle)
|
||||
|
||||
assert payload == {"keep": 1, "masked_msa": 3}
|
||||
|
||||
|
||||
def test_remove_irrelevant_pickles_keeps_only_best_pickle(tmp_path):
|
||||
best = tmp_path / "result_model_1.pkl"
|
||||
other = tmp_path / "result_model_2.pkl"
|
||||
note = tmp_path / "note.txt"
|
||||
best.write_bytes(b"best")
|
||||
other.write_bytes(b"other")
|
||||
note.write_text("keep", encoding="utf-8")
|
||||
|
||||
post_modelling.remove_irrelevant_pickles(str(tmp_path), best.name)
|
||||
|
||||
assert best.is_file()
|
||||
assert not other.exists()
|
||||
assert note.is_file()
|
||||
|
||||
|
||||
def test_post_prediction_process_remove_keys_then_compress_and_prune(tmp_path):
|
||||
ranking_debug = tmp_path / "ranking_debug.json"
|
||||
ranking_debug.write_text(json.dumps({"order": ["model_1"]}), encoding="utf-8")
|
||||
result_best = tmp_path / "result_model_1.pkl"
|
||||
result_other = tmp_path / "result_model_2.pkl"
|
||||
with open(result_best, "wb") as handle:
|
||||
pickle.dump({"keep": 1, "distogram": 2, "masked_msa": 3}, handle)
|
||||
with open(result_other, "wb") as handle:
|
||||
pickle.dump({"keep": 4, "aligned_confidence_probs": 5}, handle)
|
||||
|
||||
post_modelling.post_prediction_process(
|
||||
str(tmp_path),
|
||||
compress_pickles=True,
|
||||
remove_pickles=True,
|
||||
remove_keys=True,
|
||||
)
|
||||
|
||||
assert not result_best.exists()
|
||||
assert not result_other.exists()
|
||||
gz_best = tmp_path / "result_model_1.pkl.gz"
|
||||
assert gz_best.is_file()
|
||||
with gzip.open(gz_best, "rb") as handle:
|
||||
payload = pickle.load(handle)
|
||||
assert payload == {"keep": 1}
|
||||
|
||||
|
||||
def test_post_prediction_process_handles_missing_ranking_debug(tmp_path):
|
||||
lonely_pickle = tmp_path / "result_model_1.pkl"
|
||||
lonely_pickle.write_bytes(b"payload")
|
||||
|
||||
post_modelling.post_prediction_process(str(tmp_path), compress_pickles=True)
|
||||
|
||||
assert lonely_pickle.is_file()
|
||||
Reference in New Issue
Block a user