Files
D-SCRIPT/dscript/tests/test_foldseek.py
2025-08-12 14:05:49 +02:00

226 lines
7.7 KiB
Python

import torch
from loguru import logger
from dscript.foldseek import (
fold_vocab,
get_3di_sequences,
get_foldseek_onehot,
)
class TestFoldseek:
    """Test cases for the ``dscript.foldseek`` module.

    Exercises the ``fold_vocab`` mapping, the one-hot encoder
    ``get_foldseek_onehot`` and the 3Di extraction helper
    ``get_3di_sequences``.
    """

    def test_fold_vocab_completeness(self):
        """Test that fold_vocab contains exactly the expected codes."""
        # The 20 standard one-letter codes plus "X".
        expected_codes = set("DPVQAWKEITLFGSMHCRYNX")
        assert set(fold_vocab) == expected_codes
        assert len(fold_vocab) == 21

        # Values must be the unique integers 0..20.
        values = list(fold_vocab.values())
        assert len(set(values)) == len(values)  # All unique
        assert all(isinstance(v, int) for v in values)
        assert min(values) == 0
        assert max(values) == 20

    def test_get_foldseek_onehot_valid_sequence(self):
        """Test one-hot encoding with a valid sequence in fold_record."""
        n0 = "test_protein"
        fold_sequence = "DPVQA"
        size_n0 = len(fold_sequence)
        fold_record = {n0: fold_sequence}

        result = get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)

        # Check dimensions
        assert result.shape == (size_n0, len(fold_vocab))
        assert result.dtype == torch.float32

        # Each row holds a single 1, located at the vocab index of its code.
        for i, code in enumerate(fold_sequence):
            assert torch.sum(result[i]) == 1.0
            assert result[i, fold_vocab[code]] == 1.0

    def test_get_foldseek_onehot_protein_not_in_record(self):
        """Test one-hot encoding when protein is not in fold_record."""
        n0 = "missing_protein"
        size_n0 = 5
        fold_record = {"other_protein": "DPVQA"}

        result = get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)

        # Should return all zeros
        assert result.shape == (size_n0, len(fold_vocab))
        assert result.dtype == torch.float32
        assert torch.all(result == 0.0)

    def test_get_foldseek_onehot_size_mismatch(self):
        """Test assertion error when size doesn't match sequence length."""
        n0 = "test_protein"
        fold_sequence = "DPVQA"
        size_n0 = 3  # Different from actual sequence length (5)
        fold_record = {n0: fold_sequence}

        # The "did it raise?" check must live outside the ``try``: a failing
        # ``assert`` inside it would itself be caught by
        # ``except AssertionError`` and the test could never fail.
        try:
            get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)
        except AssertionError:
            raised = True  # Expected behavior
        else:
            raised = False
        assert raised, "Should have raised AssertionError"

    def test_get_foldseek_onehot_invalid_amino_acid(self):
        """Test assertion error with invalid amino acid in sequence."""
        n0 = "test_protein"
        fold_sequence = "DPVQZ"  # Z is not in fold_vocab
        size_n0 = len(fold_sequence)
        fold_record = {n0: fold_sequence}

        # See test_get_foldseek_onehot_size_mismatch: the final assertion is
        # kept outside the ``try`` so it cannot be swallowed by the handler.
        try:
            get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)
        except AssertionError:
            raised = True  # Expected behavior
        else:
            raised = False
        assert raised, "Should have raised AssertionError"

    def test_get_foldseek_onehot_empty_sequence(self):
        """Test one-hot encoding with empty sequence."""
        n0 = "empty_protein"
        fold_sequence = ""
        size_n0 = 0
        fold_record = {n0: fold_sequence}

        result = get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)

        assert result.shape == (0, len(fold_vocab))
        assert result.dtype == torch.float32

    def test_get_foldseek_onehot_all_amino_acids(self):
        """Test one-hot encoding with all possible amino acids."""
        # Create a sequence with every code in fold_vocab, each used once.
        all_amino_acids = "".join(sorted(fold_vocab))
        n0 = "all_aa_protein"
        size_n0 = len(all_amino_acids)
        fold_record = {n0: all_amino_acids}

        result = get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)

        assert result.shape == (size_n0, len(fold_vocab))

        # Each position should have exactly one 1
        for i in range(size_n0):
            assert torch.sum(result[i]) == 1.0

        # Each vocab column should be hit by exactly one position.
        for idx in fold_vocab.values():
            aa_positions = torch.where(result[:, idx] == 1.0)[0]
            assert len(aa_positions) == 1  # Should appear exactly once

    def test_get_foldseek_onehot_repeated_amino_acids(self):
        """Test one-hot encoding with repeated amino acids."""
        n0 = "repeat_protein"
        fold_sequence = "DDDD"  # All D amino acids
        size_n0 = len(fold_sequence)
        fold_record = {n0: fold_sequence}

        result = get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)

        assert result.shape == (size_n0, len(fold_vocab))

        # Every row should have only the "D" column set to 1.
        d_idx = fold_vocab["D"]
        for i in range(size_n0):
            assert result[i, d_idx] == 1.0
            assert torch.sum(result[i]) == 1.0

    def test_get_3di_sequences_with_cif(self):
        """Test get_3di_sequences on the bundled 8JZ0 CIF file."""
        import os

        # The fixture path is resolved relative to this test file.
        cif_path = os.path.join(os.path.dirname(__file__), "../../data/8JZ0.cif")
        cif_path = os.path.abspath(cif_path)

        result = get_3di_sequences([cif_path])

        assert isinstance(result, dict)
        assert "8JZ0" in result
        seq_record = result["8JZ0"]
        assert isinstance(seq_record, str)
        assert len(seq_record) > 0

    def test_get_3di_sequences_empty_pdb_list(self):
        """Test get_3di_sequences with empty PDB file list."""
        pdb_files = []
        try:
            result = get_3di_sequences(pdb_files)
        except Exception as e:
            # May fail due to empty input, which is acceptable. The broad
            # catch is deliberate: any failure of the call is tolerated.
            logger.warning(f"Unexpected error with empty PDB list: {e}")
        else:
            # Checked outside the ``try`` so a wrong return type is reported
            # instead of being swallowed by the handler above.
            assert isinstance(result, dict)

    def test_get_3di_sequences_nonexistent_pdb_files(self):
        """Test get_3di_sequences with non-existent PDB files."""
        pdb_files = ["nonexistent1.pdb", "nonexistent2.pdb"]
        try:
            result = get_3di_sequences(pdb_files)
        except Exception as e:
            # Expected to fail with non-existent files; any failure of the
            # call itself is tolerated on purpose.
            logger.error(f"Unexpected error with non-existent PDB files: {e}")
        else:
            # If it succeeds, result should be a dict. Checked outside the
            # ``try`` so the assertion cannot be swallowed.
            assert isinstance(result, dict)

    def test_fold_vocab_mapping(self):
        """Test specific mappings in fold_vocab."""
        # Test a few specific mappings
        assert fold_vocab["D"] == 0
        assert fold_vocab["P"] == 1
        assert fold_vocab["V"] == 2
        assert fold_vocab["X"] == 20  # Should be the last one

    def test_get_foldseek_onehot_tensor_properties(self):
        """Test tensor properties of the output."""
        n0 = "test_protein"
        fold_sequence = "DPVQA"
        size_n0 = len(fold_sequence)
        fold_record = {n0: fold_sequence}

        result = get_foldseek_onehot(n0, size_n0, fold_record, fold_vocab)

        # Test tensor properties
        assert isinstance(result, torch.Tensor)
        assert result.device == torch.device("cpu")  # Default device
        assert not result.requires_grad  # Should not require gradients by default
        assert result.dtype == torch.float32

        # Test value range
        assert torch.all(result >= 0.0)
        assert torch.all(result <= 1.0)
        assert torch.all((result == 0.0) | (result == 1.0))  # Only 0s and 1s