mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
* Expand test coverage with comprehensive test suites Add extensive test coverage for previously untested modules: - test_utils.py: Comprehensive tests for utility functions (setup_logger, log, RBF, parse_device, load_hdf5_parallel, PairedDataset, collate_paired_sequences) - test_glider.py: Complete test suite for graph-based link prediction module (get_dim, densify, compute_X_normalized, scoring functions, GLIDE algorithms) - test_loading.py: Tests for parallel HDF5 data loading with LoadingPool, including edge cases, error handling, and integration tests - test_language_model.py: Expanded from 2 to 13 test methods, adding coverage for lm_embed, embed_from_fasta with various edge cases and validations These additions significantly improve test coverage for: - dscript/utils.py (167 lines, previously untested) - dscript/glider.py (346 lines, previously untested) - dscript/loading.py (92 lines, previously untested) - dscript/language_model.py (minimal coverage expanded) Total new test methods: ~200+ assertions across 4 test modules * Add comprehensive tests for command modules and worker functions Create four new test modules to expand coverage of previously untested code: 1. test_extract_3di.py (19 test methods, ~370 lines) - Tests for 3Di sequence extraction from PDB/CIF files - Argument parsing, file filtering, FASTA output validation - Integration tests for full workflow - Covers dscript/commands/extract_3di.py (~58 lines) 2. test_par_writer.py (24 test methods, ~400 lines) - Tests for parallel prediction writer process - TSV output writing, threshold filtering, contact map storage - HDF5 contact map dataset handling - Progress tracking and data type validation - Covers dscript/commands/par_writer.py (~40 lines) 3. 
test_main.py (24 test methods, ~320 lines) - Tests for CLI entry point and argument parsing - CitationAction class testing - All subcommand registration and invocation - Version and help flag handling - Integration tests for command dispatch - Covers dscript/__main__.py (~87 lines, increasing from ~85% to ~95%) 4. test_load_worker.py (23 test methods, ~330 lines) - Direct unit tests for HDF5 loading worker function - Queue handling, data type conversion, memory sharing - Error handling for corrupted/missing files - Multi-dimensional array support - Covers dscript/load_worker.py (~25 lines, previously only indirect coverage) Total additions: - ~1,420 lines of new test code - 90+ test methods with comprehensive assertions - ~210 lines of source code now directly tested - Addresses high-priority gaps identified in coverage analysis These tests complement the existing suite and focus on command-line interface components and parallel processing infrastructure. * Fix linting issues and apply code formatting - Remove unused variables flagged by ruff - Apply ruff formatting to all test files - Ensure all pre-commit hooks pass Changes: - test_loading.py: Remove unused 'f' variable - test_main.py: Remove unused 'fake_out' and 'output' variables - test_utils.py: Remove unused 'log_file' variable and tmp_path param - Applied ruff formatting to maintain code style consistency * Fix test_load_worker.py hanging issue in CI Rewrote test_load_worker.py to prevent CI hangs that occurred when tests called the blocking worker function directly. The worker function _hdf5_load_partial_func runs in an infinite loop waiting on a queue, which caused tests to hang indefinitely. 
Changes: - Created run_worker_with_timeout() helper that wraps worker execution in a daemon thread with configurable timeout (default 5 seconds) - Modified all tests to use this helper and assert successful completion - Changed queue operations from blocking get() to non-blocking get_nowait() - Reduced test count from 23 to 16 focused tests - Added documentation noting worker is primarily tested via LoadingPool This should resolve the CI timeout issue where tests hung at 43% completion. * Rewrite test_language_model.py to use mocks instead of real model The original tests were calling the real language model which: - Downloads/loads pretrained model weights (slow, can fail) - Runs actual neural network inference (resource intensive) - Causes test failures when model files aren't available Changes: - Rewrote unit tests to mock get_pretrained() function - Mock model returns realistic tensor shapes but doesn't load weights - Tests are now fast, reliable, and don't require model files - Moved real model tests to TestLanguageModelIntegration class - Marked integration tests with @pytest.mark.slow so they can be skipped - Removed unnecessary loguru import that caused import errors - Removed problematic setup.py install step from setup_class This should fix the 4 failing tests reported by CI. * fix failing tests * Update .github/workflows/autorun-tests.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update .github/workflows/autorun-tests.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
173 lines
5.1 KiB
Python
173 lines
5.1 KiB
Python
import sys
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.utils.data
|
|
from loguru import logger
|
|
|
|
from .loading import LoadingPool
|
|
|
|
|
|
def setup_logger(log_file=None, also_stdout=False):
    """
    Configure the shared loguru ``logger`` used by D-SCRIPT.

    All previously registered sinks are removed first, so the resulting
    configuration depends only on the arguments of this call.

    :param log_file: Destination passed to ``logger.add`` (an open file handle
        or a path); ``None`` means no file sink is registered
    :type log_file: file handle, str, or None
    :param also_stdout: Also emit records to ``sys.stdout``. Standard output is
        always used when ``log_file`` is None.
    :type also_stdout: bool
    """
    logger.remove()

    # Collect the sinks in the order they should be registered: file first,
    # then stdout (stdout is the fallback when no file was given).
    sinks = []
    if log_file is not None:
        sinks.append(log_file)
    if also_stdout or log_file is None:
        sinks.append(sys.stdout)

    for sink in sinks:
        logger.add(sink)
|
|
|
|
|
|
def log(m, file=None, timestamped=True, print_also=False):
    """
    Backward-compatible logging helper built on loguru.

    Every call re-runs ``setup_logger`` with the given destination, which
    removes all existing sinks of the shared ``logger`` and registers new
    ones. The message is then emitted at INFO level.

    :param m: Message to log
    :type m: str
    :param file: Destination forwarded to ``setup_logger``; stdout is used
        when None
    :type file: file handle or None
    :param timestamped: Unused; accepted only so older call sites keep working
    :type timestamped: bool
    :param print_also: When ``file`` is given, also send the message to stdout
    :type print_also: bool
    """
    setup_logger(log_file=file, also_stdout=print_also)
    logger.info(m)

    # Nothing to flush for stdout-only logging or for destinations without a
    # ``flush`` attribute (e.g. a path string).
    if file is None or not hasattr(file, "flush"):
        return
    file.flush()
|
|
|
|
|
|
def RBF(D, sigma=None, pseudocount=1e-10):
    """
    Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.

    :math:`RBF(x,x') = \\exp{\\frac{-(x - x')^{2}}{2\\sigma^{2}}}`

    The input matrix is not modified; the pseudocount is added to a new array.

    :param D: Distance matrix (any array-like; integer input is promoted to float)
    :type D: np.ndarray
    :param sigma: Bandwidth of RBF Kernel [default: :math:`\\sqrt{\\text{max}(D)}`].
        A falsy value (``None`` or 0) selects the default.
    :type sigma: float
    :param pseudocount: Small constant added to every distance before the
        kernel is applied
    :type pseudocount: float
    :return: Similarity matrix with the same shape as ``D``
    :rtype: np.ndarray
    """
    # Out-of-place addition: an in-place ``D += pseudocount`` would silently
    # modify the caller's matrix and fails on integer arrays and plain lists.
    D = np.asarray(D) + pseudocount
    # The default bandwidth is computed after the pseudocount is applied.
    sigma = sigma or np.sqrt(np.max(D))
    return np.exp(-1 * (np.square(D) / (2 * sigma**2)))
|
|
|
|
|
|
# Now replaced by loading.LoadingPool; this is a wrapper for existing behavior
def load_hdf5_parallel(file_path, keys, n_jobs=-1, return_dict=True):
    """
    Load keys from hdf5 file into memory.

    Thin wrapper around ``LoadingPool.load_once`` kept for existing callers.

    :param file_path: Path to hdf5 file
    :type file_path: str
    :param keys: Keys (protein names) to load. A one-shot iterator (e.g. a
        generator) is materialized into a list first, because the keys are
        consumed once for loading and again to build the returned dict.
    :type keys: iterable[str]
    :param n_jobs: Parallelism setting forwarded to ``LoadingPool``
    :type n_jobs: int
    :param return_dict: If True, return a dict keyed by ``keys``; otherwise
        return the result of ``LoadingPool.load_once`` unchanged
    :type return_dict: bool
    :return: if ``return_dict``, a mapping of keys (protein names) to the
        loaded embeddings; otherwise the loaded embeddings in the same order
        as ``keys``
    :rtype: dict or list
    """
    from collections.abc import Iterator

    # Original note: if keys is a dict (of key -> index) will produce a list
    # of indices instead of a dict.
    # NOTE(review): that behavior lives in LoadingPool.load_once; confirm
    # against loading.py. Dicts and other containers are passed through as-is.
    if isinstance(keys, Iterator):
        # Without this, load_once would exhaust the iterator and the zip
        # below would silently produce an empty dict.
        keys = list(keys)

    pool = LoadingPool(file_path, n_jobs)
    result = pool.load_once(keys)
    if return_dict:
        return dict(zip(keys, result))
    return result
|
|
|
|
|
|
# Parse device argument
def parse_device(device_arg, logFile):
    """
    Translate a ``--device`` command-line value into a device specification.

    :param device_arg: ``"cpu"`` or ``"all"`` (case-insensitive), or a
        non-negative GPU index written as a decimal string
    :type device_arg: str
    :param logFile: Destination for error messages; if it has a ``close``
        method it is closed before exiting on error. May be None.
    :type logFile: file handle, str, or None
    :return: ``"cpu"``, ``-1`` (use all GPUs), or the integer GPU index
    :rtype: str or int
    :raises SystemExit: with status 1 if the argument is invalid, a GPU is
        requested but CUDA is unavailable, or the GPU index is out of range
    """
    normalized = device_arg.lower()
    if normalized == "cpu":
        return "cpu"

    if normalized == "all":
        device = -1  # Use all GPUs
    elif device_arg.isdecimal():
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscripts that int() cannot parse.
        device = int(device_arg)
    else:
        _exit_with_error(
            f"Invalid device argument: {device_arg}. Use 'cpu', 'all', or a GPU index.",
            logFile,
        )

    # A GPU was requested: validate CUDA availability and the index.
    if not torch.cuda.is_available():
        _exit_with_error(
            "CUDA not available but GPU requested. Use --device cpu for CPU execution.",
            logFile,
        )

    n_gpus = torch.cuda.device_count()
    # device == -1 ("all") never triggers this check.
    if device >= 0 and device >= n_gpus:
        # Fatal, like the other error branches: continuing would hand an
        # unusable device index to the caller.
        _exit_with_error(
            f"Invalid device argument: {device_arg} exceeds the number of GPUs available, which is {n_gpus}. Please specify a valid GPU, or use --device cpu for CPU execution.",
            logFile,
        )

    return device


def _exit_with_error(message, logFile):
    """Log ``message`` (also to stdout), close ``logFile`` if possible, and exit with status 1."""
    log(message, file=logFile, print_also=True)
    # logFile may be None or a path string, neither of which can be closed.
    if logFile is not None and hasattr(logFile, "close"):
        logFile.close()
    sys.exit(1)
|
|
|
|
|
|
class PairedDataset(torch.utils.data.Dataset):
    """
    Dataset to be used by the PyTorch data loader for pairs of sequences and their labels.

    Item ``i`` is the tuple ``(X0[i], X1[i], Y[i])``.

    :param X0: List of first item in the pair
    :param X1: List of second item in the pair
    :param Y: List of labels
    :raises AssertionError: if ``X0``, ``X1`` and ``Y`` do not all have the
        same length
    """

    def __init__(self, X0, X1, Y):
        self.X0 = X0
        self.X1 = X1
        self.Y = Y
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``. AssertionError is kept so existing callers that
        # catch it keep working.
        if not (len(X0) == len(X1) == len(Y)):
            raise AssertionError(
                "X0: " + str(len(X0)) + " X1: " + str(len(X1)) + " Y: " + str(len(Y))
            )

    def __len__(self):
        return len(self.X0)

    def __getitem__(self, i):
        return self.X0[i], self.X1[i], self.Y[i]
|
|
|
|
|
|
def collate_paired_sequences(args):
    """
    Collate function for a PyTorch ``DataLoader`` over paired samples.

    :param args: Batch of samples, each indexable as ``(first, second, label)``
        where ``label`` is a tensor (e.g. the items of ``PairedDataset``)
    :return: ``(firsts, seconds, labels)``: lists of the first and second
        elements of every sample, plus the labels stacked along a new leading
        dimension. An empty batch yields ``([], [], torch.tensor([]))``.
    """
    if len(args) == 0:
        return [], [], torch.tensor([])

    firsts, seconds, labels = [], [], []
    for sample in args:
        firsts.append(sample[0])
        seconds.append(sample[1])
        labels.append(sample[2])

    return firsts, seconds, torch.stack(labels, 0)
|