mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
New.
This commit is contained in:
38
test/unit/test_distogram_parser.py
Normal file
38
test/unit/test_distogram_parser.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import alphapulldown.utils.distogram_parser as distogram_parser_module
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_contacts_returns_empty_list_when_no_pickles_exist(monkeypatch, tmp_path):
    """``get_contacts`` returns ``[]`` when ``datadir`` is an empty directory."""
    # ``raising=False`` lets the patch succeed even if the module does not
    # define ``datadir`` at import time.
    monkeypatch.setattr(
        distogram_parser_module, "datadir", str(tmp_path), raising=False
    )

    contacts = distogram_parser_module.distogram_parser().get_contacts("ignored")

    assert contacts == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_contacts_extracts_top_ranked_inter_chain_contact(monkeypatch, tmp_path):
    """A distogram with one confident symmetric pair yields exactly one contact.

    The logits favour the first distance bin only at positions (0, 2) and
    (2, 0); every other entry is strongly negative. The expected contact is
    labelled ``(1, "A")`` / ``(1, "B")`` with probability above 0.99.
    """
    scores = np.full((4, 4, 3), -10.0, dtype=np.float32)
    for row, col in ((0, 2), (2, 0)):
        scores[row, col, 0] = 10.0

    result = {
        "ranking_confidence": 0.9,
        "seqs": ["AA", "BB"],
        "distogram": {
            "bin_edges": np.array([4.0, 8.0, 12.0], dtype=np.float32),
            "logits": scores,
        },
    }
    # Serialise the fake prediction result into the directory the parser reads.
    (tmp_path / "result_model.pkl").write_bytes(pickle.dumps(result))
    monkeypatch.setattr(
        distogram_parser_module, "datadir", str(tmp_path), raising=False
    )

    contacts = distogram_parser_module.distogram_parser().get_contacts(
        "ignored", distance=9, pbtycutoff=0.5, cross_only=True
    )

    assert len(contacts) == 1
    top = contacts[0]
    assert top[0] == (1, "A")
    assert top[1] == (1, "B")
    assert top[2] > 0.99
|
||||||
118
test/unit/test_file_handling.py
Normal file
118
test/unit/test_file_handling.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from alphapulldown.utils import file_handling
|
||||||
|
|
||||||
|
|
||||||
|
def test_temp_fasta_file_writes_contents_and_cleans_up():
    """The temp FASTA file holds the given text and is gone after the block."""
    fasta_text = ">seq\nACDE\n"

    with file_handling.temp_fasta_file(fasta_text) as fasta_path:
        written = Path(fasta_path)
        assert written.is_file()
        assert written.read_text(encoding="utf-8") == fasta_text

    # Leaving the context manager must delete the file.
    assert not written.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_fasta_description_to_protein_name_sanitizes_symbols():
    """The leading ``>`` is dropped and each symbol or space becomes ``_``."""
    raw_description = ">sp|P12345|Protein A:chain;1?"

    sanitized = file_handling.convert_fasta_description_to_protein_name(
        raw_description
    )

    assert sanitized == "sp_P12345_Protein_A_chain_1_"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_fasta_parses_multiple_sequences_and_skips_blank_lines():
    """Two records are parsed; a blank line inside a record is ignored."""
    fasta_text = ">protein one\nACD\n\nEF\n>protein|two\nGHI\n"

    sequences, descriptions = file_handling.parse_fasta(fasta_text)

    # "ACD" and "EF" straddle the blank line and must be joined.
    assert sequences == ["ACDEF", "GHI"]
    assert descriptions == ["protein_one", "protein_two"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_iter_seqs_yields_sequences_from_multiple_files(tmp_path):
    """``iter_seqs`` yields ``(sequence, description)`` pairs in file order."""
    contents = {
        "a.fasta": ">first\nAAA\n",
        "b.fasta": ">second\nBBB\n",
    }
    fasta_paths = []
    for file_name, text in contents.items():
        target = tmp_path / file_name
        target.write_text(text, encoding="utf-8")
        fasta_paths.append(str(target))

    yielded = list(file_handling.iter_seqs(fasta_paths))

    assert yielded == [("AAA", "first"), ("BBB", "second")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_dir_monomer_dictionary_maps_files_to_their_source_dirs(tmp_path):
    """Each file name maps to the directory that contains it."""
    first_dir = tmp_path / "a"
    second_dir = tmp_path / "b"
    for directory, file_name in (
        (first_dir, "proteinA.pkl"),
        (second_dir, "proteinB.pkl.xz"),
    ):
        directory.mkdir()
        # Only the file's presence is set up here, so it is left empty.
        (directory / file_name).write_text("", encoding="utf-8")

    mapping = file_handling.make_dir_monomer_dictionary(
        [str(first_dir), str(second_dir)]
    )

    expected = {
        "proteinA.pkl": str(first_dir),
        "proteinB.pkl.xz": str(second_dir),
    }
    assert mapping == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_csv_file_raises_for_missing_fasta(tmp_path):
    """``parse_csv_file`` raises ``FileNotFoundError`` for non-existent inputs.

    The missing paths are anchored inside ``tmp_path`` so the outcome does not
    depend on the process working directory (a stray ``missing.fasta`` or
    ``missing.csv`` there would otherwise change the result).
    """
    missing_csv = tmp_path / "missing.csv"
    missing_fasta = tmp_path / "missing.fasta"

    with pytest.raises(FileNotFoundError):
        file_handling.parse_csv_file(
            str(missing_csv), [str(missing_fasta)], "/templates"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_csv_file_returns_unique_records_without_clustering(tmp_path):
    """With ``cluster=False`` each valid CSV row becomes its own record.

    The CSV also contains a two-column row and a row naming a protein absent
    from the FASTA file; neither appears in the expected output.
    """
    templates_dir = tmp_path / "templates"
    fasta_path = tmp_path / "proteins.fasta"
    description_path = tmp_path / "description.csv"
    fasta_path.write_text(">proteinA\nACDE\n", encoding="utf-8")
    description_path.write_text(
        "proteinA,template1.cif,A\nproteinA,template2.cif,B\ninvalid,row\nmissing,template3.cif,C\n",
        encoding="utf-8",
    )

    records = file_handling.parse_csv_file(
        str(description_path), [str(fasta_path)], str(templates_dir), cluster=False
    )

    expected = [
        {
            "protein": "proteinA.template1.cif.A",
            "sequence": "ACDE",
            "templates": [str(templates_dir / "template1.cif")],
            "chains": ["A"],
        },
        {
            "protein": "proteinA.template2.cif.B",
            "sequence": "ACDE",
            "templates": [str(templates_dir / "template2.cif")],
            "chains": ["B"],
        },
    ]
    assert records == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_csv_file_clusters_multiple_templates_per_protein(tmp_path):
    """With ``cluster=True`` rows for one protein merge into a single record."""
    templates_dir = tmp_path / "templates"
    fasta_path = tmp_path / "proteins.fasta"
    description_path = tmp_path / "description.csv"
    fasta_path.write_text(">proteinA\nACDE\n", encoding="utf-8")
    description_path.write_text(
        "proteinA,template1.cif,A\nproteinA,template2.cif,B\n",
        encoding="utf-8",
    )

    records = file_handling.parse_csv_file(
        str(description_path), [str(fasta_path)], str(templates_dir), cluster=True
    )

    # Templates and chains are collected in CSV row order.
    expected = [
        {
            "protein": "proteinA",
            "sequence": "ACDE",
            "templates": [
                str(templates_dir / "template1.cif"),
                str(templates_dir / "template2.cif"),
            ],
            "chains": ["A", "B"],
        }
    ]
    assert records == expected
|
||||||
17
test/unit/test_plotting.py
Normal file
17
test/unit/test_plotting.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from alphapulldown.utils.plotting import plot_pae_from_matrix
|
||||||
|
|
||||||
|
|
||||||
|
def test_plot_pae_from_matrix_writes_png_file(tmp_path):
    """Plotting a 5x5 PAE matrix for two sequences writes a non-empty file."""
    figure_path = tmp_path / "pae.png"
    # 5x5 matches the combined length of the two sequences (2 + 3).
    pae = np.arange(25, dtype=float).reshape(5, 5)

    plot_pae_from_matrix(
        seqs=["AA", "BBB"],
        pae_matrix=pae,
        figure_name=str(figure_path),
        ranking=2,
    )

    assert figure_path.is_file()
    assert figure_path.stat().st_size > 0
|
||||||
94
test/unit/test_post_modelling.py
Normal file
94
test/unit/test_post_modelling.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from alphapulldown.utils import post_modelling
|
||||||
|
|
||||||
|
|
||||||
|
def test_compress_file_gzips_and_removes_original(tmp_path):
    """``compress_file`` returns ``<path>.gz``, removes the source, keeps the bytes."""
    source = tmp_path / "result.pkl"
    source.write_bytes(b"payload")

    compressed = post_modelling.compress_file(str(source))

    assert compressed == str(source) + ".gz"
    assert not source.exists()
    # Decompress to confirm the content survived the round trip.
    with gzip.open(compressed, "rb") as stream:
        restored = stream.read()
    assert restored == b"payload"
|
||||||
|
|
||||||
|
|
||||||
|
def test_compress_result_pickles_only_compresses_pickle_files(tmp_path):
    """The ``.pkl`` file is replaced by a ``.pkl.gz``; the ``.txt`` file stays."""
    pkl_file = tmp_path / "result.pkl"
    txt_file = tmp_path / "notes.txt"
    pkl_file.write_bytes(b"payload")
    txt_file.write_text("keep", encoding="utf-8")

    post_modelling.compress_result_pickles(str(tmp_path))

    assert not pkl_file.exists()
    assert (tmp_path / "result.pkl.gz").is_file()
    assert txt_file.is_file()
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_keys_from_pickle_updates_pickle_in_place(tmp_path):
    """Listed keys are dropped from the pickle; an absent key is tolerated."""
    target = tmp_path / "result.pkl"
    target.write_bytes(pickle.dumps({"keep": 1, "distogram": 2, "masked_msa": 3}))

    # "missing" is not in the stored dict and must not cause an error.
    post_modelling.remove_keys_from_pickle(str(target), ["distogram", "missing"])

    remaining = pickle.loads(target.read_bytes())
    assert remaining == {"keep": 1, "masked_msa": 3}
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_irrelevant_pickles_keeps_only_best_pickle(tmp_path):
    """The named pickle and non-pickle files survive; the other pickle is deleted."""
    kept_pickle = tmp_path / "result_model_1.pkl"
    dropped_pickle = tmp_path / "result_model_2.pkl"
    text_file = tmp_path / "note.txt"
    kept_pickle.write_bytes(b"best")
    dropped_pickle.write_bytes(b"other")
    text_file.write_text("keep", encoding="utf-8")

    post_modelling.remove_irrelevant_pickles(str(tmp_path), kept_pickle.name)

    assert kept_pickle.is_file()
    assert not dropped_pickle.exists()
    assert text_file.is_file()
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_prediction_process_remove_keys_then_compress_and_prune(tmp_path):
    """All three options together leave one gzipped, stripped-down pickle.

    ``ranking_debug.json`` lists ``model_1`` first. Afterwards neither
    uncompressed pickle remains, and ``result_model_1.pkl.gz`` unpickles to
    ``{"keep": 1}``.
    """
    (tmp_path / "ranking_debug.json").write_text(
        json.dumps({"order": ["model_1"]}), encoding="utf-8"
    )
    best_pickle = tmp_path / "result_model_1.pkl"
    other_pickle = tmp_path / "result_model_2.pkl"
    best_pickle.write_bytes(
        pickle.dumps({"keep": 1, "distogram": 2, "masked_msa": 3})
    )
    other_pickle.write_bytes(
        pickle.dumps({"keep": 4, "aligned_confidence_probs": 5})
    )

    post_modelling.post_prediction_process(
        str(tmp_path),
        compress_pickles=True,
        remove_pickles=True,
        remove_keys=True,
    )

    assert not best_pickle.exists()
    assert not other_pickle.exists()
    best_archive = tmp_path / "result_model_1.pkl.gz"
    assert best_archive.is_file()
    with gzip.open(best_archive, "rb") as stream:
        remaining = pickle.load(stream)
    assert remaining == {"keep": 1}
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_prediction_process_handles_missing_ranking_debug(tmp_path):
    """Without ``ranking_debug.json`` the existing pickle is left in place."""
    only_pickle = tmp_path / "result_model_1.pkl"
    only_pickle.write_bytes(b"payload")

    # No ranking_debug.json is created in tmp_path before this call.
    post_modelling.post_prediction_process(str(tmp_path), compress_pickles=True)

    assert only_pickle.is_file()
|
||||||
Reference in New Issue
Block a user