update language model to fix issue #35

This commit is contained in:
samsledje
2022-08-18 12:10:32 -04:00
parent b7834b6caa
commit 988c1b35ff
4 changed files with 49 additions and 8 deletions

View File

@@ -7,10 +7,15 @@
## v0.2
### v0.2.1
### v0.2.2
- Resolve #35 to use `require_dataset` -- can now add multiple .fasta files to the same h5 file
- Update pretrained API and docs to include Topsy-Turvy
- Add retry decorator to get_pretrained if download fails
### v0.2.1: 2022-06-28 -- Bug fixes
- Add biopython to setup.py
### v0.2.0
### v0.2.0: 2022-06-24 -- Integration of Topsy Turvy
- Integrate Topsy-Turvy to allow for top-down supervision
- Use utils.log function across all commands

View File

@@ -94,7 +94,7 @@ def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False):
)
)
h5fi = h5py.File(outputPath, "w")
h5fi = h5py.File(outputPath, "a")
log("# Storing to {}...".format(outputPath))
with torch.no_grad():
@@ -103,9 +103,13 @@ def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False):
if name not in h5fi:
x = x.long().unsqueeze(0)
z = model.transform(x)
h5fi.create_dataset(
name, data=z.cpu().numpy(), compression="lzf"
dset = h5fi.require_dataset(
name,
shape=z.shape,
dtype="float32",
compression="lzf",
)
dset[:] = z.cpu().numpy()
except KeyboardInterrupt:
h5fi.close()
sys.exit(1)

View File

@@ -0,0 +1,35 @@
import os
import shutil
import subprocess as sp
from Bio import SeqIO
from dscript.language_model import (
lm_embed,
embed_from_fasta,
)
class TestLanguageModel:
@classmethod
def setup_class(cls):
cmd = "python setup.py install"
proc = sp.Popen(cmd.split())
proc.wait()
os.makedirs("./tmp-dscript-testing/", exist_ok=True)
@classmethod
def teardown_class(cls):
shutil.rmtree("./tmp-dscript-testing/")
def test_lm_embed(self):
seqs = list(SeqIO.parse("dscript/tests/test.fasta", "fasta"))
for seqrec in seqs:
x = lm_embed(str(seqrec.seq))
assert x.shape[1] == len(seqrec.seq)
def embed_from_fasta(self):
embed_from_fasta(
"dscript/tests/test.fasta",
"tmp-dscript-testing/test_embed.h5",
verbose=True,
)

View File

@@ -1,4 +1,3 @@
import dscript
from pathlib import Path
from dscript.pretrained import (
@@ -14,8 +13,6 @@ MODEL_VERSIONS = [
"lm_v1", # Bepler & Berger 2019
]
print(dscript.__version__)
def test_get_state_dict():
for mv in MODEL_VERSIONS: