Files
D-SCRIPT/dscript/utils.py
Samuel Sledzieski 0b1d9e5007 fix failing tests
2025-12-04 15:13:52 -05:00

173 lines
5.1 KiB
Python

import sys
import numpy as np
import torch
import torch.utils.data
from loguru import logger
from .loading import LoadingPool
def setup_logger(log_file=None, also_stdout=False):
    """
    Configure the global loguru logger used by D-SCRIPT.

    Every previously registered loguru sink is removed first, then a file
    sink and/or a stdout sink is attached.

    :param log_file: File handle or path to write logs to; ``None`` for no file sink
    :type log_file: file handle, str, or None
    :param also_stdout: Whether to also log to stdout when ``log_file`` is given
        (stdout is always used when there is no ``log_file``)
    :type also_stdout: bool
    """
    # logger.remove() with no argument drops all sinks, including loguru's default stderr one.
    logger.remove()

    has_file_sink = log_file is not None
    if has_file_sink:
        logger.add(log_file)

    # Fall back to stdout so messages are never silently discarded.
    if also_stdout or not has_file_sink:
        logger.add(sys.stdout)
def log(m, file=None, timestamped=True, print_also=False):
    """
    Legacy logging entry point, kept for backward compatibility; delegates to loguru.

    Each call reconfigures the global loguru sinks through ``setup_logger``,
    which removes any sinks added elsewhere, then emits ``m`` at INFO level.

    :param m: Message to log
    :type m: str
    :param file: File handle (or path) to write to; if None, logs go to stdout
    :type file: file handle or None
    :param timestamped: Not used by this function; accepted only so legacy
        callers keep working (timestamps come from loguru's sink format)
    :type timestamped: bool
    :param print_also: Whether to also print to stdout when writing to file
    :type print_also: bool
    """
    setup_logger(log_file=file, also_stdout=print_also)
    logger.info(m)

    # Nothing to flush for stdout-only logging or for non-stream targets (e.g. a path string).
    if file is None or not hasattr(file, "flush"):
        return
    file.flush()
def RBF(D, sigma=None, pseudocount=1e-10):
    """
    Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.

    :math:`RBF(x,x') = \\exp{\\frac{-(x - x')^{2}}{2\\sigma^{2}}}`

    The input array is not modified; ``pseudocount`` is added to a copy.

    :param D: Distance matrix (anything accepted by ``np.asarray``, including integer arrays)
    :type D: np.ndarray
    :param sigma: Bandwidth of RBF Kernel. If falsy (None or 0) it defaults to
        :math:`\\sqrt{\\text{max}(D + \\text{pseudocount})}`
    :type sigma: float
    :param pseudocount: Constant added to every distance before the kernel is applied
    :type pseudocount: float
    :return: Similarity matrix with the same shape as ``D``
    :rtype: np.ndarray
    """
    # Out-of-place add: ``D += pseudocount`` silently mutated the caller's
    # array and raised a casting error for integer-typed distance matrices.
    shifted = np.asarray(D) + pseudocount
    sigma = sigma or np.sqrt(np.max(shifted))
    return np.exp(-1 * (np.square(shifted) / (2 * sigma**2)))
# NOTE(review): pre-existing note says that if keys is a dict (of key -> index)
# this produces a list of indices instead of a dict; that behavior would come
# from LoadingPool.load_once, which is not visible here -- confirm.
# Now replaced by loading.LoadingPool; this is a wrapper for existing behavior
def load_hdf5_parallel(file_path, keys, n_jobs=-1, return_dict=True):
    """
    Load keys from hdf5 file into memory.

    :param file_path: Path to hdf5 file
    :type file_path: str
    :param keys: Keys to get. A one-shot iterator (e.g. a generator) is
        materialized into a list first, because the keys are iterated twice
        (once for loading and once for building the result dict).
    :type keys: iterable[str]
    :param n_jobs: Passed through to ``LoadingPool`` (presumably the number of
        workers, with -1 meaning "all" -- TODO confirm against LoadingPool)
    :type n_jobs: int
    :param return_dict: If True, return a dict keyed by ``keys``; otherwise
        return whatever ``LoadingPool.load_once`` returns
    :type return_dict: bool
    :return: if return_dict, a mapping of keys (protein names) to pointers to embeddings;
        otherwise, a list of pointers in the same order as keys
    :rtype: dict or list
    """
    # A generator would be exhausted by load_once, making zip() below yield
    # nothing and the function silently return {}. Lists, tuples and dicts
    # (iter(x) is not x) are passed through unchanged.
    if iter(keys) is keys:
        keys = list(keys)

    pool = LoadingPool(file_path, n_jobs)
    result = pool.load_once(keys)
    if return_dict:
        return dict(zip(keys, result))
    return result
def _exit_with_device_error(message, logFile):
    """Log ``message`` to ``logFile`` and stdout, close ``logFile`` if it is closable, then exit(1)."""
    log(message, file=logFile, print_also=True)
    # log() also accepts None or a path string; only real handles can be closed.
    if logFile is not None and hasattr(logFile, "close"):
        logFile.close()
    sys.exit(1)


# Parse device argument
def parse_device(device_arg, logFile):
    """
    Parse a ``--device`` command-line value into a device specifier.

    :param device_arg: ``"cpu"`` or ``"all"`` (case-insensitive), or a
        non-negative GPU index such as ``"0"``
    :type device_arg: str
    :param logFile: Log target passed to ``log``; closed before exiting on error
    :type logFile: file handle or None
    :return: ``"cpu"``, ``-1`` (use all GPUs), or the integer GPU index
    :rtype: str or int

    Calls ``sys.exit(1)`` (raising ``SystemExit``) when the argument is not
    recognized, when a GPU is requested but CUDA is unavailable, or when the
    GPU index is not less than ``torch.cuda.device_count()``.
    """
    arg = str(device_arg)
    lowered = arg.lower()

    if lowered == "cpu":
        return "cpu"

    if lowered == "all":
        device = -1  # Use all GPUs
    elif arg.isdecimal():
        # isdecimal (not isdigit): isdigit accepts e.g. "²", which int() rejects.
        device = int(arg)
    else:
        _exit_with_device_error(
            f"Invalid device argument: {device_arg}. Use 'cpu', 'all', or a GPU index.",
            logFile,
        )

    # A GPU was requested: validate CUDA availability and the device index.
    if not torch.cuda.is_available():
        _exit_with_device_error(
            "CUDA not available but GPU requested. Use --device cpu for CPU execution.",
            logFile,
        )

    n_gpus = torch.cuda.device_count()
    if device >= 0 and device >= n_gpus:
        # Previously this only logged and returned the invalid index, deferring
        # the failure to a confusing CUDA error later; fail fast like the other checks.
        _exit_with_device_error(
            f"Invalid device argument: {device_arg} exceeds the number of GPUs available, which is {n_gpus}. Please specify a valid GPU, or use --device cpu for CPU execution.",
            logFile,
        )
    return device
class PairedDataset(torch.utils.data.Dataset):
    """
    Dataset to be used by the PyTorch data loader for pairs of sequences and their labels.

    :param X0: List of first item in the pair
    :param X1: List of second item in the pair
    :param Y: List of labels
    :raises AssertionError: if ``X0``, ``X1`` and ``Y`` do not all have the same length
    """

    def __init__(self, X0, X1, Y):
        self.X0 = X0
        self.X1 = X1
        self.Y = Y
        # Explicit check instead of ``assert`` so validation still runs under
        # ``python -O``. AssertionError and the message are kept unchanged for
        # backward compatibility with existing callers and tests.
        if not (len(X0) == len(X1) == len(Y)):
            raise AssertionError(
                "X0: " + str(len(X0)) + " X1: " + str(len(X1)) + " Y: " + str(len(Y))
            )

    def __len__(self):
        """Return the number of pairs (the length of ``X0``)."""
        return len(self.X0)

    def __getitem__(self, i):
        """Return the ``i``-th ``(x0, x1, y)`` triple."""
        return self.X0[i], self.X1[i], self.Y[i]
def collate_paired_sequences(args):
    """
    Collate function for PyTorch data loader.

    Splits a batch of ``(x0, x1, y)`` items into two plain lists of the first
    and second pair members, plus the labels stacked along a new dim 0.
    An empty batch yields ``([], [], torch.tensor([]))``.
    """
    if len(args) == 0:
        return [], [], torch.tensor([])

    firsts, seconds, labels = [], [], []
    for item in args:
        firsts.append(item[0])
        seconds.append(item[1])
        labels.append(item[2])
    return firsts, seconds, torch.stack(labels, 0)