This commit is contained in:
Dima
2026-03-27 15:12:16 +01:00
parent 7fd121f842
commit 407764479b
3 changed files with 495 additions and 0 deletions

View File

@@ -0,0 +1,233 @@
import gzip
import json
import math
import pickle
import warnings
from pathlib import Path
import numpy as np
import pytest
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from alphapulldown.analysis_pipeline import calculate_mpdockq as mpdockq
def _atom_line(
    serial: int,
    atom_name: str,
    res_name: str,
    chain_id: str,
    residue_number: int,
    x: float,
    y: float,
    z: float,
    *,
    occupancy: float = 1.0,
    bfactor: float = 50.0,
) -> str:
    """Format one fixed-width PDB ``ATOM`` record, terminated by a newline.

    The fields are laid out on the columns of the PDB format (1-based):
    record name 1-6, serial 7-11, atom name 13-16, altLoc 17 (blank),
    residue name 18-20, chain 22, residue number 23-26, iCode 27 (blank),
    x/y/z 31-54, occupancy 55-60, B-factor 61-66, element 77-78.

    Every padding run is written with an explicit width specifier instead of
    literal spaces, so the column layout cannot drift if whitespace inside
    the string literal is altered.

    Args:
        serial: Atom serial number.
        atom_name: Atom name; written left-aligned in a 4-character field.
        res_name: Residue name; written right-aligned in a 3-character field.
        chain_id: Chain identifier.
        residue_number: Residue sequence number.
        x: X coordinate, written with three decimals.
        y: Y coordinate, written with three decimals.
        z: Z coordinate, written with three decimals.
        occupancy: Occupancy value, written with two decimals.
        bfactor: Temperature factor, written with two decimals.

    Returns:
        The formatted record (78 characters plus a trailing newline when all
        values fit their fields).
    """
    # The element symbol is the first character of the stripped atom name.
    element = atom_name.strip()[0]
    return (
        f"{'ATOM':<6}{serial:5d} {atom_name:<4}{'':1}{res_name:>3} {chain_id:1}"
        f"{residue_number:4d}{'':4}{x:8.3f}{y:8.3f}{z:8.3f}"
        f"{occupancy:6.2f}{bfactor:6.2f}{'':10}{element:>2}\n"
    )
def _write_pdb(path: Path, atoms: list[str]) -> Path:
    """Write the joined ``atoms`` records plus ``TER``/``END`` lines to ``path``.

    The file is written as UTF-8 and ``path`` is returned for convenience.
    """
    records = "".join(atoms)
    contents = f"{records}TER\nEND\n"
    path.write_text(contents, encoding="utf-8")
    return path
def test_parse_atm_record_extracts_fixed_width_fields():
    """Check that ``parse_atm_record`` decodes each field of a generated ATOM line."""
    atom_record = _atom_line(
        12, "CB", "SER", "B", 7, 1.5, 2.5, 3.5, occupancy=0.5, bfactor=42.0
    )
    parsed = mpdockq.parse_atm_record(atom_record)

    # Text and integer fields are compared exactly.
    exact_fields = {
        "name": "ATOM",
        "atm_no": 12,
        "atm_name": "CB",
        "res_name": "SER",
        "chain": "B",
        "res_no": 7,
    }
    for key, wanted in exact_fields.items():
        assert parsed[key] == wanted

    # Floating-point fields are compared with a tolerance.
    approximate_fields = {"x": 1.5, "y": 2.5, "z": 3.5, "occ": 0.5, "B": 42.0}
    for key, wanted in approximate_fields.items():
        assert parsed[key] == pytest.approx(wanted)
def test_read_pdb_extracts_chain_coordinates_and_ca_cb_indices(tmp_path):
    """Check chains, per-chain atom counts and CA/CB index maps from ``read_pdb``.

    Chain A is an alanine with N, CA and CB; chain B is a glycine with only
    N and CA. The expected CB index for chain B equals its CA index.
    """
    atom_records = [
        _atom_line(1, "N", "ALA", "A", 1, 0.0, 0.0, 0.0),
        _atom_line(2, "CA", "ALA", "A", 1, 1.0, 0.0, 0.0),
        _atom_line(3, "CB", "ALA", "A", 1, 1.0, 1.0, 0.0),
        _atom_line(4, "N", "GLY", "B", 1, 5.0, 0.0, 0.0),
        _atom_line(5, "CA", "GLY", "B", 1, 6.0, 0.0, 0.0),
    ]
    structure_file = _write_pdb(tmp_path / "chains.pdb", atom_records)

    chain_ids, coords_by_chain, ca_by_chain, cb_by_chain = mpdockq.read_pdb(
        str(structure_file)
    )

    assert sorted(chain_ids) == ["A", "B"]
    atom_counts = (len(coords_by_chain["A"]), len(coords_by_chain["B"]))
    assert atom_counts == (3, 2)
    assert ca_by_chain == {"A": [1], "B": [1]}
    assert cb_by_chain == {"A": [2], "B": [1]}
def test_parse_bfactor_averages_residue_bfactors(tmp_path):
    """Check that ``parse_bfactor`` returns the mean B-factor of each residue.

    Residue 1 carries B-factors 10 and 20, residue 2 carries 40 and 60, so
    the expected result is ``[15.0, 50.0]``.
    """
    # (serial, atom name, residue name, residue number, x, B-factor)
    atom_specs = [
        (1, "N", "ALA", 1, 0.0, 10.0),
        (2, "CA", "ALA", 1, 1.0, 20.0),
        (3, "N", "GLY", 2, 2.0, 40.0),
        (4, "CA", "GLY", 2, 3.0, 60.0),
    ]
    atom_records = [
        _atom_line(serial, name, residue, "A", number, x, 0.0, 0.0, bfactor=value)
        for serial, name, residue, number, x, value in atom_specs
    ]
    structure_file = _write_pdb(tmp_path / "bfactor.pdb", atom_records)

    # Silence Biopython construction warnings raised while parsing.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", PDBConstructionWarning)
        per_residue = mpdockq.parse_bfactor(str(structure_file))

    np.testing.assert_allclose(per_residue, np.asarray([15.0, 50.0]))
def test_get_best_plddt_prefers_uncompressed_pickle(tmp_path):
    """Check that ``get_best_plddt`` reads pLDDT from ``result_<model>.pkl``.

    The directory holds a ``ranking_debug.json`` naming ``model_1`` first and
    a plain pickle for that model.
    """
    ranking = {"order": ["model_1"]}
    ranking_file = tmp_path / "ranking_debug.json"
    ranking_file.write_text(json.dumps(ranking), encoding="utf-8")

    stored_plddt = np.asarray([91.0, 92.0])
    pickle_file = tmp_path / "result_model_1.pkl"
    with pickle_file.open("wb") as stream:
        stream.write(pickle.dumps({"plddt": stored_plddt}))

    loaded = mpdockq.get_best_plddt(str(tmp_path))

    np.testing.assert_array_equal(loaded, stored_plddt)
def test_get_best_plddt_falls_back_to_gzipped_pickle(tmp_path):
    """Check that ``get_best_plddt`` reads pLDDT from ``result_<model>.pkl.gz``.

    Only the gzip-compressed pickle exists for the top-ranked ``model_2``.
    """
    ranking = {"order": ["model_2"]}
    ranking_file = tmp_path / "ranking_debug.json"
    ranking_file.write_text(json.dumps(ranking), encoding="utf-8")

    stored_plddt = np.asarray([81.0, 82.0])
    archive = tmp_path / "result_model_2.pkl.gz"
    with gzip.open(archive, "wb") as stream:
        stream.write(pickle.dumps({"plddt": stored_plddt}))

    loaded = mpdockq.get_best_plddt(str(tmp_path))

    np.testing.assert_array_equal(loaded, stored_plddt)
def test_get_best_plddt_falls_back_to_ranked_pdb_bfactors(monkeypatch, tmp_path):
    """Check that ``get_best_plddt`` uses ``parse_bfactor`` when no pickle exists.

    The directory contains only the ranking file and a stub ``ranked_0.pdb``;
    ``parse_bfactor`` is replaced so the stub file never has to be parsed.
    """
    ranking_file = tmp_path / "ranking_debug.json"
    ranking_file.write_text(json.dumps({"order": ["model_3"]}), encoding="utf-8")
    stub_pdb = tmp_path / "ranked_0.pdb"
    stub_pdb.write_text("HEADER\n", encoding="utf-8")

    canned_plddt = np.asarray([71.0, 72.0])

    def fake_parse_bfactor(pdb_path):
        # Ignore the path and hand back the canned values.
        return canned_plddt

    monkeypatch.setattr(mpdockq, "parse_bfactor", fake_parse_bfactor)

    loaded = mpdockq.get_best_plddt(str(tmp_path))

    np.testing.assert_array_equal(loaded, canned_plddt)
def test_read_plddt_slices_values_per_chain():
    """Check that ``read_plddt`` splits a flat pLDDT array into per-chain arrays.

    Chain A has two index entries and chain B has one, so the three input
    values are expected as ``[90, 91]`` and ``[50]``.
    """
    flat_plddt = np.asarray([90.0, 91.0, 50.0])
    ca_index_map = {"A": [0, 1], "B": [0]}

    split = mpdockq.read_plddt(flat_plddt, ca_index_map)

    expected_by_chain = {
        "A": np.asarray([90.0, 91.0]),
        "B": np.asarray([50.0]),
    }
    for chain, wanted in expected_by_chain.items():
        np.testing.assert_array_equal(split[chain], wanted)
def test_score_complex_returns_zero_when_no_interface_contacts():
    """Check that ``score_complex`` gives 0 for two chains placed far apart.

    Chain B is offset by 30 along x from chain A; the second returned value
    is expected to be 2.
    """
    coordinates = {
        "A": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
        "B": [[30.0, 0.0, 0.0], [31.0, 0.0, 0.0]],
    }
    index_map = {"A": [1], "B": [1]}
    plddt_by_chain = {"A": np.asarray([80.0]), "B": np.asarray([85.0])}

    complex_score, chain_count = mpdockq.score_complex(
        coordinates, index_map, plddt_by_chain
    )

    assert complex_score == 0
    assert chain_count == 2
def test_score_complex_and_mpdockq_are_positive_for_contacting_chains():
    """Check ``score_complex`` and ``calculate_mpDockQ`` for two nearby chains.

    Chain B sits 5 units above chain A along z. The mpDockQ value is compared
    against the sigmoid written out in this test.
    """
    coordinates = {
        "A": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
        "B": [[0.0, 0.0, 5.0], [1.0, 0.0, 5.0]],
    }
    index_map = {"A": [1], "B": [1]}
    plddt_by_chain = {"A": np.asarray([90.0]), "B": np.asarray([90.0])}

    complex_score, chain_count = mpdockq.score_complex(
        coordinates, index_map, plddt_by_chain
    )
    actual = mpdockq.calculate_mpDockQ(complex_score)

    # Reference sigmoid evaluated at the score returned above.
    exponent = -0.036 * (complex_score - 261.398)
    reference = 0.827 / (1 + math.exp(exponent)) + 0.221

    assert chain_count == 2
    assert complex_score > 0
    assert actual == pytest.approx(reference)
def test_read_pdb_pdockq_uses_cb_and_gly_ca_coordinates(tmp_path):
    """Check which atom ``read_pdb_pdockq`` keeps for each residue.

    For the alanine in chain A the CB coordinates and B-factor are expected;
    for the glycine in chain B the CA coordinates and B-factor are expected.
    """
    # (serial, atom name, residue name, chain, (x, y, z), B-factor)
    atom_specs = [
        (1, "N", "ALA", "A", (0.0, 0.0, 0.0), 20.0),
        (2, "CA", "ALA", "A", (1.0, 0.0, 0.0), 30.0),
        (3, "CB", "ALA", "A", (1.0, 1.0, 0.0), 40.0),
        (4, "N", "GLY", "B", (5.0, 0.0, 0.0), 50.0),
        (5, "CA", "GLY", "B", (6.0, 0.0, 0.0), 60.0),
    ]
    atom_records = [
        _atom_line(serial, name, residue, chain, 1, *xyz, bfactor=value)
        for serial, name, residue, chain, xyz, value in atom_specs
    ]
    structure_file = _write_pdb(tmp_path / "pdockq.pdb", atom_records)

    coords_by_chain, plddt_by_chain = mpdockq.read_pdb_pdockq(str(structure_file))

    np.testing.assert_array_equal(
        coords_by_chain["A"], np.asarray([[1.0, 1.0, 0.0]])
    )
    np.testing.assert_array_equal(
        coords_by_chain["B"], np.asarray([[6.0, 0.0, 0.0]])
    )
    np.testing.assert_array_equal(plddt_by_chain["A"], np.asarray([40.0]))
    np.testing.assert_array_equal(plddt_by_chain["B"], np.asarray([60.0]))
def test_calc_pdockq_returns_zero_without_contacts():
    """Check that ``calc_pdockq`` returns 0 for single points 20 units apart.

    The third positional argument is passed as 8.
    """
    coords_by_chain = {
        "A": np.asarray([[0.0, 0.0, 0.0]]),
        "B": np.asarray([[20.0, 0.0, 0.0]]),
    }
    plddt_by_chain = {
        "A": np.asarray([90.0]),
        "B": np.asarray([80.0]),
    }

    result = mpdockq.calc_pdockq(coords_by_chain, plddt_by_chain, 8)

    assert result == 0
def test_calc_pdockq_returns_positive_score_for_contacting_chains():
    """Check that ``calc_pdockq`` lies strictly between 0 and 1 for nearby chains.

    Each chain has two points; chain B sits 5 units above chain A along z.
    The third positional argument is passed as 8.
    """
    coords_by_chain = {
        "A": np.asarray([[0.0, 0.0, 0.0], [0.0, 3.0, 0.0]]),
        "B": np.asarray([[0.0, 0.0, 5.0], [0.0, 3.0, 5.0]]),
    }
    plddt_by_chain = {
        "A": np.asarray([90.0, 92.0]),
        "B": np.asarray([88.0, 86.0]),
    }

    result = mpdockq.calc_pdockq(coords_by_chain, plddt_by_chain, 8)

    assert 0 < result < 1

View File

@@ -0,0 +1,198 @@
import logging
from pathlib import Path
import pytest
from Bio.PDB import PDBParser
from alphapulldown.utils import calculate_rmsd
def _atom_line(
    serial: int,
    atom_name: str,
    res_name: str,
    chain_id: str,
    residue_number: int,
    x: float,
    y: float,
    z: float,
    *,
    bfactor: float = 20.0,
) -> str:
    """Format one fixed-width PDB ``ATOM`` record, terminated by a newline.

    The fields are laid out on the columns of the PDB format (1-based):
    record name 1-6, serial 7-11, atom name 13-16, altLoc 17 (blank),
    residue name 18-20, chain 22, residue number 23-26, iCode 27 (blank),
    x/y/z 31-54, occupancy 55-60 (always 1.00), B-factor 61-66,
    element 77-78.

    Every padding run is written with an explicit width specifier instead of
    literal spaces, so the column layout cannot drift if whitespace inside
    the string literal is altered.

    Args:
        serial: Atom serial number.
        atom_name: Atom name; written left-aligned in a 4-character field.
        res_name: Residue name; written right-aligned in a 3-character field.
        chain_id: Chain identifier.
        residue_number: Residue sequence number.
        x: X coordinate, written with three decimals.
        y: Y coordinate, written with three decimals.
        z: Z coordinate, written with three decimals.
        bfactor: Temperature factor, written with two decimals.

    Returns:
        The formatted record (78 characters plus a trailing newline when all
        values fit their fields).
    """
    # The element symbol is the first character of the stripped atom name.
    element = atom_name.strip()[0]
    return (
        f"{'ATOM':<6}{serial:5d} {atom_name:<4}{'':1}{res_name:>3} {chain_id:1}"
        f"{residue_number:4d}{'':4}{x:8.3f}{y:8.3f}{z:8.3f}"
        f"{1.00:6.2f}{bfactor:6.2f}{'':10}{element:>2}\n"
    )
def _write_pdb(path: Path, atoms: list[str]) -> Path:
    """Write the joined ``atoms`` records plus ``TER``/``END`` lines to ``path``.

    The file is written as UTF-8 and ``path`` is returned for convenience.
    """
    records = "".join(atoms)
    contents = f"{records}TER\nEND\n"
    path.write_text(contents, encoding="utf-8")
    return path
def _parse_structure(path: Path):
    """Parse the PDB file at ``path`` with a quiet ``PDBParser``.

    The file stem is used as the structure id.
    """
    parser = PDBParser(QUIET=True)
    structure_id = path.stem
    return parser.get_structure(structure_id, str(path))
def test_setup_logging_configures_basic_config(monkeypatch):
    """Check the keyword arguments ``setup_logging`` passes to ``logging.basicConfig``."""
    recorded = {}

    # Replace basicConfig with a recorder so global logging state is untouched.
    monkeypatch.setattr(
        logging,
        "basicConfig",
        lambda **kwargs: recorded.update(kwargs),
    )

    calculate_rmsd.setup_logging()

    expected_kwargs = {
        "format": "%(asctime)s - %(levelname)s: %(message)s",
        "level": logging.INFO,
    }
    assert recorded == expected_kwargs
def test_extract_ca_sequence_skips_missing_ca_and_marks_unknown_residues(tmp_path):
    """Check the sequence ``extract_ca_sequence`` builds from a three-residue chain.

    ALA (with CA) is expected as ``A``, UNK (with CA) as ``-``, and the GLY
    that has no CA atom is expected to be left out.
    """
    # (serial, atom name, residue name, residue number, x)
    atom_specs = [
        (1, "N", "ALA", 1, 0.0),
        (2, "CA", "ALA", 1, 1.0),
        (3, "N", "UNK", 2, 2.0),
        (4, "CA", "UNK", 2, 3.0),
        (5, "N", "GLY", 3, 4.0),
    ]
    atom_records = [
        _atom_line(serial, name, residue, "A", number, x, 0.0, 0.0)
        for serial, name, residue, number, x in atom_specs
    ]
    structure_file = _write_pdb(tmp_path / "sequence.pdb", atom_records)
    structure = _parse_structure(structure_file)

    extracted = calculate_rmsd.extract_ca_sequence(structure)

    assert extracted == "A-"
def test_align_sequences_returns_global_alignment_for_identical_sequences():
    """Check score, target and query when ``ACD`` is aligned against itself."""
    sequence = "ACD"

    result = calculate_rmsd.align_sequences(sequence, sequence)

    assert result.score == 6
    assert result.target == "ACD"
    assert result.query == "ACD"
def test_get_common_atoms_returns_only_shared_atom_ids(tmp_path):
    """Check that ``get_common_atoms`` pairs only atoms present in both residues.

    The reference residue has N, CA, CB and the target has CA, C, O, so CA is
    the only shared atom id.
    """
    reference_records = [
        _atom_line(1, "N", "ALA", "A", 1, 0.0, 0.0, 0.0),
        _atom_line(2, "CA", "ALA", "A", 1, 1.0, 0.0, 0.0),
        _atom_line(3, "CB", "ALA", "A", 1, 1.0, 1.0, 0.0),
    ]
    target_records = [
        _atom_line(1, "CA", "ALA", "A", 1, 10.0, 0.0, 0.0),
        _atom_line(2, "C", "ALA", "A", 1, 11.0, 0.0, 0.0),
        _atom_line(3, "O", "ALA", "A", 1, 12.0, 0.0, 0.0),
    ]
    reference_file = _write_pdb(tmp_path / "ref_atoms.pdb", reference_records)
    target_file = _write_pdb(tmp_path / "target_atoms.pdb", target_records)

    # Each file holds a single residue; take the first one from each.
    reference_residue = next(_parse_structure(reference_file).get_residues())
    target_residue = next(_parse_structure(target_file).get_residues())

    pairs = calculate_rmsd.get_common_atoms(reference_residue, target_residue)

    shared_ids = [reference_atom.get_id() for reference_atom, _unused in pairs]
    assert shared_ids == ["CA"]
def test_process_chain_collects_common_atoms_from_matching_residues(tmp_path):
    """Check the atom lists ``process_chain`` returns for chain A.

    The reference has ALA (N, CA, CB) and GLY (N, CA); the target has
    ALA (N, CA) and GLY (CA). The expected atom ids on both sides are
    N and CA from ALA followed by CA from GLY.
    """
    reference_records = [
        _atom_line(1, "N", "ALA", "A", 1, 0.0, 0.0, 0.0),
        _atom_line(2, "CA", "ALA", "A", 1, 1.0, 0.0, 0.0),
        _atom_line(3, "CB", "ALA", "A", 1, 1.0, 1.0, 0.0),
        _atom_line(4, "N", "GLY", "A", 2, 2.0, 0.0, 0.0),
        _atom_line(5, "CA", "GLY", "A", 2, 3.0, 0.0, 0.0),
    ]
    target_records = [
        _atom_line(1, "N", "ALA", "A", 1, 10.0, 0.0, 0.0),
        _atom_line(2, "CA", "ALA", "A", 1, 11.0, 0.0, 0.0),
        _atom_line(3, "CA", "GLY", "A", 2, 13.0, 0.0, 0.0),
    ]
    reference_file = _write_pdb(tmp_path / "ref_chain.pdb", reference_records)
    target_file = _write_pdb(tmp_path / "target_chain.pdb", target_records)

    reference_structure = _parse_structure(reference_file)
    target_structure = _parse_structure(target_file)
    sequence_alignment = calculate_rmsd.align_sequences("AG", "AG")

    matched_reference, matched_target = calculate_rmsd.process_chain(
        "A", reference_structure, target_structure, sequence_alignment
    )

    expected_ids = ["N", "CA", "CA"]
    assert [atom.get_id() for atom in matched_reference] == expected_ids
    assert [atom.get_id() for atom in matched_target] == expected_ids
def test_calculate_rmsd_and_superpose_returns_rmsd_and_writes_outputs(tmp_path):
    """Check RMSD and output files for a target that is a pure x-translation.

    Both structures hold the same six backbone atoms spaced one unit apart
    along x; the target is shifted by 10. The expected RMSD is approximately
    zero and both superposed PDB files must exist in ``tmp_path``.
    """
    # (atom name, residue name, residue number) in file order.
    backbone = [
        ("N", "ALA", 1),
        ("CA", "ALA", 1),
        ("C", "ALA", 1),
        ("N", "GLY", 2),
        ("CA", "GLY", 2),
        ("C", "GLY", 2),
    ]

    def backbone_records(x_offset):
        # Atom i (0-based) gets serial i + 1 and x = x_offset + i.
        return [
            _atom_line(index + 1, name, residue, "A", number, x_offset + index, 0.0, 0.0)
            for index, (name, residue, number) in enumerate(backbone)
        ]

    reference_file = _write_pdb(tmp_path / "ref.pdb", backbone_records(0.0))
    target_file = _write_pdb(tmp_path / "target.pdb", backbone_records(10.0))

    result = calculate_rmsd.calculate_rmsd_and_superpose(
        str(reference_file),
        str(target_file),
        temp_dir=str(tmp_path),
    )

    assert result == pytest.approx(0.0)
    for output_name in ("superposed_ref.pdb", "superposed_target.pdb"):
        assert (tmp_path / output_name).is_file()
def test_calculate_rmsd_and_superpose_returns_none_without_matching_chains(
    tmp_path,
    caplog,
):
    """Check the result and log message when the structures share no chain id.

    The reference has a single CA in chain A and the target a single CA in
    chain B; ``None`` and an error-level log entry are expected.
    """
    reference_records = [_atom_line(1, "CA", "ALA", "A", 1, 0.0, 0.0, 0.0)]
    target_records = [_atom_line(1, "CA", "ALA", "B", 1, 10.0, 0.0, 0.0)]
    reference_file = _write_pdb(tmp_path / "ref_nomatch.pdb", reference_records)
    target_file = _write_pdb(tmp_path / "target_nomatch.pdb", target_records)

    with caplog.at_level(logging.ERROR):
        result = calculate_rmsd.calculate_rmsd_and_superpose(
            str(reference_file),
            str(target_file),
            temp_dir=str(tmp_path),
        )

    assert result is None
    assert "No suitable atoms found for RMSD calculation." in caplog.text

View File

@@ -0,0 +1,64 @@
import numpy as np
from alphapulldown.utils import msa_encoding
def test_get_id_to_char_map_includes_gap_and_unknown():
    """Check that id 21 maps to the gap ``-`` and id 20 maps to ``X``."""
    id_to_char = msa_encoding.get_id_to_char_map()

    expected_symbols = {21: "-", 20: "X"}
    for token_id, symbol in expected_symbols.items():
        assert id_to_char[token_id] == symbol
def test_get_char_to_id_map_inverts_gap_and_standard_residue():
    """Check that the char-to-id map inverts the id-to-char map for ``-`` and id 0."""
    forward = msa_encoding.get_id_to_char_map()
    inverse = msa_encoding.get_char_to_id_map()

    assert inverse["-"] == 21
    first_symbol = forward[0]
    assert inverse[first_symbol] == 0
def test_ids_to_a3m_emits_expected_headers_and_rows():
    """Check the A3M text ``ids_to_a3m`` produces for a two-row id matrix.

    Each row is expected under a ``>sequence_<index>`` header, with ids
    rendered through the id-to-char map and id 21 rendered as ``-``.
    """
    id_matrix = np.asarray([[0, 6, 21], [20, 5, 4]], dtype=np.int32)
    symbols = msa_encoding.get_id_to_char_map()

    rendered = msa_encoding.ids_to_a3m(id_matrix)

    first_row = f"{symbols[0]}{symbols[6]}-"
    second_row = f"{symbols[20]}{symbols[5]}{symbols[4]}"
    wanted = f">sequence_0\n{first_row}\n>sequence_1\n{second_row}\n"
    assert rendered == wanted
def test_a3m_to_ids_strips_insertions_and_maps_unknowns():
    """Check the id matrix ``a3m_to_ids`` builds from a two-sequence A3M string.

    ``Acd-`` is expected as the ids of ``A`` and ``-``; ``Zx-`` is expected
    as id 20 followed by the id of ``-``.
    """
    symbol_ids = msa_encoding.get_char_to_id_map()
    gap_id = symbol_ids["-"]
    a3m_text = ">sequence_0\nAcd-\n>sequence_1\nZx-\n"

    parsed = msa_encoding.a3m_to_ids(a3m_text)

    wanted_rows = [
        [symbol_ids["A"], gap_id],
        [20, gap_id],
    ]
    wanted = np.asarray(wanted_rows, dtype=np.int32)
    np.testing.assert_array_equal(parsed, wanted)
def test_a3m_to_ids_returns_empty_matrix_for_empty_input():
    """Check that an empty A3M string yields a ``(0, 0)`` ``int32`` matrix."""
    empty_result = msa_encoding.a3m_to_ids("")

    assert empty_result.shape == (0, 0)
    assert empty_result.dtype == np.int32
def test_ids_to_a3m_af3_uses_af3_alphabet_and_unknown_fallback():
    """Check the text ``ids_to_a3m_af3`` renders for ids 0, 21, 22, 29, 30 and 99.

    The expected sequence line is ``A-ATNX``.
    """
    id_matrix = np.asarray([[0, 21, 22, 29, 30, 99]], dtype=np.int32)

    rendered = msa_encoding.ids_to_a3m_af3(id_matrix)

    wanted = ">sequence_0\nA-ATNX\n"
    assert rendered == wanted