mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
fix failing tests
This commit is contained in:
20
.github/workflows/autorun-tests.yml
vendored
20
.github/workflows/autorun-tests.yml
vendored
@@ -15,18 +15,24 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python 3.10
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e ".[test,dev]"
|
||||
uv sync --extra test --extra dev && uv lock
|
||||
|
||||
- name: Lint and format with ruff
|
||||
run: |
|
||||
ruff check . --statistics
|
||||
ruff format .
|
||||
uv run ruff check . --statistics
|
||||
uv run ruff format .
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest --cov=dscript --cov-report=xml --cov-report=term-missing
|
||||
uv run pytest --cov-report=xml --cov-report=term-missing
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
__version__ = "0.3.1"
|
||||
from importlib.metadata import version as _get_version
|
||||
|
||||
__version__ = _get_version("dscript")
|
||||
__citation__ = """Sledzieski, Singh, Cowen, Berger. "D-SCRIPT translates genome to phenome with sequence-based, structure-aware, genome-scale predictions of protein-protein interactions." Cell Systems 12, no. 10 (2021): 969-982.
|
||||
|
||||
Devkota, Singh, Sledzieski, Berger, Cowen, Topsy-Turvy: integrating a global view into sequence-based PPI prediction, Bioinformatics, In Press."""
|
||||
|
||||
@@ -28,15 +28,15 @@ class TestLanguageModelUnit:
|
||||
|
||||
# Mock the proj layer
|
||||
model.proj = Mock()
|
||||
model.proj.weight = torch.randn(100, 100)
|
||||
model.proj.bias = torch.zeros(100)
|
||||
model.proj.weight = torch.randn(6165, 6165)
|
||||
model.proj.bias = torch.zeros(6165)
|
||||
|
||||
# Mock transform to return realistic embeddings
|
||||
def mock_transform(x):
|
||||
batch_size = x.shape[0]
|
||||
seq_len = x.shape[1]
|
||||
# Return (batch, seq_len, embedding_dim=100)
|
||||
return torch.randn(batch_size, seq_len, 100)
|
||||
# Return (batch, seq_len, embedding_dim=6165)
|
||||
return torch.randn(batch_size, seq_len, 6165)
|
||||
|
||||
model.transform = Mock(side_effect=mock_transform)
|
||||
return model
|
||||
@@ -49,10 +49,10 @@ class TestLanguageModelUnit:
|
||||
test_seq = "MKTAYIAKQRQISFVKSHFSRQ"
|
||||
x = lm_embed(test_seq, use_cuda=False)
|
||||
|
||||
# Should be (batch=1, seq_len, embedding_dim=100)
|
||||
# Should be (batch=1, seq_len, embedding_dim=6165)
|
||||
assert x.shape[0] == 1
|
||||
assert x.shape[1] == len(test_seq)
|
||||
assert x.shape[2] == 100
|
||||
assert x.shape[2] == 6165
|
||||
|
||||
@patch("dscript.language_model.get_pretrained")
|
||||
def test_lm_embed_returns_tensor(self, mock_get_pretrained, mock_model):
|
||||
@@ -73,7 +73,7 @@ class TestLanguageModelUnit:
|
||||
x = lm_embed(short_seq, use_cuda=False)
|
||||
|
||||
assert x.shape[1] == 2
|
||||
assert x.shape[2] == 100
|
||||
assert x.shape[2] == 6165
|
||||
|
||||
@patch("dscript.language_model.get_pretrained")
|
||||
def test_lm_embed_single_amino_acid(self, mock_get_pretrained, mock_model):
|
||||
@@ -84,7 +84,7 @@ class TestLanguageModelUnit:
|
||||
x = lm_embed(single_aa, use_cuda=False)
|
||||
|
||||
assert x.shape[1] == 1
|
||||
assert x.shape[2] == 100
|
||||
assert x.shape[2] == 6165
|
||||
|
||||
@patch("dscript.language_model.get_pretrained")
|
||||
def test_embed_from_fasta_creates_h5(self, mock_get_pretrained, mock_model, tmp_path):
|
||||
@@ -215,7 +215,7 @@ class TestLanguageModelIntegration:
|
||||
|
||||
assert x.shape[0] == 1
|
||||
assert x.shape[1] == len(test_seq)
|
||||
assert x.shape[2] == 100
|
||||
assert x.shape[2] == 6165
|
||||
assert isinstance(x, torch.Tensor)
|
||||
|
||||
def test_embed_from_fasta_real(self, tmp_path):
|
||||
@@ -239,4 +239,4 @@ class TestLanguageModelIntegration:
|
||||
embedding = f[name][:]
|
||||
assert embedding.shape[0] == 1
|
||||
assert embedding.shape[1] == len(seq)
|
||||
assert embedding.shape[2] == 100
|
||||
assert embedding.shape[2] == 6165
|
||||
|
||||
@@ -89,7 +89,7 @@ class TestMainFunction:
|
||||
"test.tsv",
|
||||
"--embedding",
|
||||
"embed.h5",
|
||||
"--output",
|
||||
"--outfile",
|
||||
"output",
|
||||
"--save-prefix",
|
||||
"model",
|
||||
@@ -162,13 +162,13 @@ class TestMainFunction:
|
||||
test_args = [
|
||||
"dscript",
|
||||
"predict_bipartite",
|
||||
"--seqs0",
|
||||
"seqs0.fasta",
|
||||
"--seqs1",
|
||||
"seqs1.fasta",
|
||||
"--embeddings0",
|
||||
"--protA",
|
||||
"protA.txt",
|
||||
"--protB",
|
||||
"protB.txt",
|
||||
"--embedA",
|
||||
"embed0.h5",
|
||||
"--embeddings1",
|
||||
"--embedB",
|
||||
"embed1.h5",
|
||||
"--model",
|
||||
"model_path",
|
||||
@@ -188,7 +188,7 @@ class TestMainFunction:
|
||||
test_args = [
|
||||
"dscript",
|
||||
"evaluate",
|
||||
"--pairs",
|
||||
"--test",
|
||||
"pairs.tsv",
|
||||
"--embeddings",
|
||||
"embed.h5",
|
||||
@@ -261,7 +261,7 @@ class TestMainFunction:
|
||||
"test.tsv",
|
||||
"--embedding",
|
||||
"embed.h5",
|
||||
"--output",
|
||||
"--outfile",
|
||||
"output",
|
||||
"--save-prefix",
|
||||
"model",
|
||||
|
||||
@@ -95,7 +95,7 @@ class TestLog:
|
||||
mock_file.flush = Mock()
|
||||
|
||||
log("flush test", file=mock_file)
|
||||
mock_file.flush.assert_called_once()
|
||||
mock_file.flush.assert_called()
|
||||
|
||||
|
||||
class TestRBF:
|
||||
|
||||
@@ -53,7 +53,7 @@ def log(m, file=None, timestamped=True, print_also=False):
|
||||
file.flush()
|
||||
|
||||
|
||||
def RBF(D, sigma=None):
|
||||
def RBF(D, sigma=None, pseudocount=1e-10):
|
||||
"""
|
||||
Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.
|
||||
|
||||
@@ -66,6 +66,7 @@ def RBF(D, sigma=None):
|
||||
:return: Similarity matrix
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
D += pseudocount
|
||||
sigma = sigma or np.sqrt(np.max(D))
|
||||
return np.exp(-1 * (np.square(D) / (2 * sigma**2)))
|
||||
|
||||
@@ -161,7 +162,11 @@ def collate_paired_sequences(args):
|
||||
"""
|
||||
Collate function for PyTorch data loader.
|
||||
"""
|
||||
if not len(args):
|
||||
return [], [], torch.tensor([])
|
||||
|
||||
x0 = [a[0] for a in args]
|
||||
x1 = [a[1] for a in args]
|
||||
y = [a[2] for a in args]
|
||||
return x0, x1, torch.stack(y, 0)
|
||||
y = torch.stack([a[2] for a in args], 0)
|
||||
|
||||
return x0, x1, y
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = ["uv_build>=0.9.15,<0.10.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[project]
|
||||
name = "dscript"
|
||||
dynamic = ["version"]
|
||||
version = "0.3.1"
|
||||
description = "D-SCRIPT: protein-protein interaction prediction"
|
||||
authors = [
|
||||
{name = "Samuel Sledzieski", email = "samsl@mit.edu"}
|
||||
@@ -43,6 +43,7 @@ docs = [
|
||||
"jinja2<3.1",
|
||||
"Sphinx==3.4",
|
||||
"sphinx-rtd-theme==1.0.0",
|
||||
"pydata-sphinx-theme",
|
||||
"sphinxcontrib-applehelp==1.0.2",
|
||||
"sphinxcontrib-devhelp==1.0.2",
|
||||
"sphinxcontrib-htmlhelp==1.0.3",
|
||||
@@ -57,11 +58,14 @@ Homepage = "http://dscript.csail.mit.edu"
|
||||
[project.scripts]
|
||||
dscript = "dscript.__main__:main"
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = {attr = "dscript.__version__"}
|
||||
[tool.uv]
|
||||
package = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["dscript*"]
|
||||
[tool.uv.sources]
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "dscript"
|
||||
module-root = ""
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 90
|
||||
@@ -89,6 +93,9 @@ known-first-party = ["dscript"]
|
||||
[tool.coverage.paths]
|
||||
source = ["dscript"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["dscript/tests/*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
filterwarnings = [
|
||||
"ignore::UserWarning",
|
||||
|
||||
Reference in New Issue
Block a user