fix failing tests

This commit is contained in:
Samuel Sledzieski
2025-12-04 15:13:52 -05:00
parent 09c2bea9b2
commit 0b1d9e5007
9 changed files with 2520 additions and 38 deletions

View File

@@ -15,18 +15,24 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.10"
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v7
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[test,dev]"
uv sync --extra test --extra dev && uv lock
- name: Lint and format with ruff
run: |
ruff check . --statistics
ruff format .
uv run ruff check . --statistics
uv run ruff format .
- name: Test with pytest
run: |
pytest --cov=dscript --cov-report=xml --cov-report=term-missing
uv run pytest --cov-report=xml --cov-report=term-missing

View File

@@ -1,4 +1,6 @@
__version__ = "0.3.1"
from importlib.metadata import version as _get_version
__version__ = _get_version("dscript")
__citation__ = """Sledzieski, Singh, Cowen, Berger. "D-SCRIPT translates genome to phenome with sequence-based, structure-aware, genome-scale predictions of protein-protein interactions." Cell Systems 12, no. 10 (2021): 969-982.
Devkota, Singh, Sledzieski, Berger, Cowen, Topsy-Turvy: integrating a global view into sequence-based PPI prediction, Bioinformatics, In Press."""

View File

@@ -28,15 +28,15 @@ class TestLanguageModelUnit:
# Mock the proj layer
model.proj = Mock()
model.proj.weight = torch.randn(100, 100)
model.proj.bias = torch.zeros(100)
model.proj.weight = torch.randn(6165, 6165)
model.proj.bias = torch.zeros(6165)
# Mock transform to return realistic embeddings
def mock_transform(x):
batch_size = x.shape[0]
seq_len = x.shape[1]
# Return (batch, seq_len, embedding_dim=100)
return torch.randn(batch_size, seq_len, 100)
# Return (batch, seq_len, embedding_dim=6165)
return torch.randn(batch_size, seq_len, 6165)
model.transform = Mock(side_effect=mock_transform)
return model
@@ -49,10 +49,10 @@ class TestLanguageModelUnit:
test_seq = "MKTAYIAKQRQISFVKSHFSRQ"
x = lm_embed(test_seq, use_cuda=False)
# Should be (batch=1, seq_len, embedding_dim=100)
# Should be (batch=1, seq_len, embedding_dim=6165)
assert x.shape[0] == 1
assert x.shape[1] == len(test_seq)
assert x.shape[2] == 100
assert x.shape[2] == 6165
@patch("dscript.language_model.get_pretrained")
def test_lm_embed_returns_tensor(self, mock_get_pretrained, mock_model):
@@ -73,7 +73,7 @@ class TestLanguageModelUnit:
x = lm_embed(short_seq, use_cuda=False)
assert x.shape[1] == 2
assert x.shape[2] == 100
assert x.shape[2] == 6165
@patch("dscript.language_model.get_pretrained")
def test_lm_embed_single_amino_acid(self, mock_get_pretrained, mock_model):
@@ -84,7 +84,7 @@ class TestLanguageModelUnit:
x = lm_embed(single_aa, use_cuda=False)
assert x.shape[1] == 1
assert x.shape[2] == 100
assert x.shape[2] == 6165
@patch("dscript.language_model.get_pretrained")
def test_embed_from_fasta_creates_h5(self, mock_get_pretrained, mock_model, tmp_path):
@@ -215,7 +215,7 @@ class TestLanguageModelIntegration:
assert x.shape[0] == 1
assert x.shape[1] == len(test_seq)
assert x.shape[2] == 100
assert x.shape[2] == 6165
assert isinstance(x, torch.Tensor)
def test_embed_from_fasta_real(self, tmp_path):
@@ -239,4 +239,4 @@ class TestLanguageModelIntegration:
embedding = f[name][:]
assert embedding.shape[0] == 1
assert embedding.shape[1] == len(seq)
assert embedding.shape[2] == 100
assert embedding.shape[2] == 6165

View File

@@ -89,7 +89,7 @@ class TestMainFunction:
"test.tsv",
"--embedding",
"embed.h5",
"--output",
"--outfile",
"output",
"--save-prefix",
"model",
@@ -162,13 +162,13 @@ class TestMainFunction:
test_args = [
"dscript",
"predict_bipartite",
"--seqs0",
"seqs0.fasta",
"--seqs1",
"seqs1.fasta",
"--embeddings0",
"--protA",
"protA.txt",
"--protB",
"protB.txt",
"--embedA",
"embed0.h5",
"--embeddings1",
"--embedB",
"embed1.h5",
"--model",
"model_path",
@@ -188,7 +188,7 @@ class TestMainFunction:
test_args = [
"dscript",
"evaluate",
"--pairs",
"--test",
"pairs.tsv",
"--embeddings",
"embed.h5",
@@ -261,7 +261,7 @@ class TestMainFunction:
"test.tsv",
"--embedding",
"embed.h5",
"--output",
"--outfile",
"output",
"--save-prefix",
"model",

View File

@@ -95,7 +95,7 @@ class TestLog:
mock_file.flush = Mock()
log("flush test", file=mock_file)
mock_file.flush.assert_called_once()
mock_file.flush.assert_called()
class TestRBF:

View File

@@ -53,7 +53,7 @@ def log(m, file=None, timestamped=True, print_also=False):
file.flush()
def RBF(D, sigma=None, pseudocount=1e-10):
    """
    Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.

    Computes ``exp(-(D + pseudocount)**2 / (2 * sigma**2))`` element-wise.
    The input ``D`` is left unmodified.

    :param D: Distance matrix
    :type D: np.ndarray
    :param sigma: Bandwidth of the RBF kernel; when omitted (or falsy) it defaults to ``sqrt(max(D + pseudocount))``
    :type sigma: float
    :param pseudocount: Constant added to every distance before the kernel is applied; keeps the default sigma non-zero when every distance is zero
    :type pseudocount: float
    :return: Similarity matrix
    :rtype: np.ndarray
    """
    # Add out of place: an in-place `D += pseudocount` would silently alter the
    # caller's matrix on every call and raises a casting error for integer arrays.
    shifted = np.asarray(D) + pseudocount
    if not sigma:
        sigma = np.sqrt(np.max(shifted))
    return np.exp(-np.square(shifted) / (2 * sigma**2))
@@ -161,7 +162,11 @@ def collate_paired_sequences(args):
"""
Collate function for PyTorch data loader.
"""
if not len(args):
return [], [], torch.tensor([])
x0 = [a[0] for a in args]
x1 = [a[1] for a in args]
y = [a[2] for a in args]
return x0, x1, torch.stack(y, 0)
y = torch.stack([a[2] for a in args], 0)
return x0, x1, y

View File

@@ -1,10 +1,10 @@
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
requires = ["uv_build>=0.9.15,<0.10.0"]
build-backend = "uv_build"
[project]
name = "dscript"
dynamic = ["version"]
version = "0.3.1"
description = "D-SCRIPT: protein-protein interaction prediction"
authors = [
{name = "Samuel Sledzieski", email = "samsl@mit.edu"}
@@ -43,6 +43,7 @@ docs = [
"jinja2<3.1",
"Sphinx==3.4",
"sphinx-rtd-theme==1.0.0",
"pydata-sphinx-theme",
"sphinxcontrib-applehelp==1.0.2",
"sphinxcontrib-devhelp==1.0.2",
"sphinxcontrib-htmlhelp==1.0.3",
@@ -57,11 +58,14 @@ Homepage = "http://dscript.csail.mit.edu"
[project.scripts]
dscript = "dscript.__main__:main"
[tool.setuptools.dynamic]
version = {attr = "dscript.__version__"}
[tool.uv]
package = true
[tool.setuptools.packages.find]
include = ["dscript*"]
[tool.uv.sources]
[tool.uv.build-backend]
module-name = "dscript"
module-root = ""
[tool.ruff]
line-length = 90
@@ -89,6 +93,9 @@ known-first-party = ["dscript"]
[tool.coverage.paths]
source = ["dscript"]
[tool.coverage.run]
omit = ["dscript/tests/*"]
[tool.pytest.ini_options]
filterwarnings = [
"ignore::UserWarning",

2462
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff