mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
update language model to fix issue #35
This commit is contained in:
@@ -7,10 +7,15 @@
|
||||
|
||||
## v0.2
|
||||
|
||||
### v0.2.1
|
||||
### v0.2.2
|
||||
- Resolve #35 to use `require_dataset` -- can now add multiple .fasta files to the same h5 file
|
||||
- Update pretrained API and docs to include Topsy-Turvy
|
||||
- Add retry decorator to get_pretrained if download fails
|
||||
|
||||
### v0.2.1: 2022-06-28 -- Bug fixes
|
||||
- Add biopython to setup.py
|
||||
|
||||
### v0.2.0
|
||||
### v0.2.0: 2022-06-24 -- Integration of Topsy Turvy
|
||||
|
||||
- Integrate Topsy-Turvy to allow for top-down supervision
|
||||
- Use utils.log function across all commands
|
||||
|
||||
@@ -94,7 +94,7 @@ def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False):
|
||||
)
|
||||
)
|
||||
|
||||
h5fi = h5py.File(outputPath, "w")
|
||||
h5fi = h5py.File(outputPath, "a")
|
||||
|
||||
log("# Storing to {}...".format(outputPath))
|
||||
with torch.no_grad():
|
||||
@@ -103,9 +103,13 @@ def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False):
|
||||
if name not in h5fi:
|
||||
x = x.long().unsqueeze(0)
|
||||
z = model.transform(x)
|
||||
h5fi.create_dataset(
|
||||
name, data=z.cpu().numpy(), compression="lzf"
|
||||
dset = h5fi.require_dataset(
|
||||
name,
|
||||
shape=z.shape,
|
||||
dtype="float32",
|
||||
compression="lzf",
|
||||
)
|
||||
dset[:] = z.cpu().numpy()
|
||||
except KeyboardInterrupt:
|
||||
h5fi.close()
|
||||
sys.exit(1)
|
||||
|
||||
35
dscript/tests/test_language_model.py
Normal file
35
dscript/tests/test_language_model.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess as sp
|
||||
from Bio import SeqIO
|
||||
|
||||
from dscript.language_model import (
|
||||
lm_embed,
|
||||
embed_from_fasta,
|
||||
)
|
||||
|
||||
|
||||
class TestLanguageModel:
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
cmd = "python setup.py install"
|
||||
proc = sp.Popen(cmd.split())
|
||||
proc.wait()
|
||||
os.makedirs("./tmp-dscript-testing/", exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
shutil.rmtree("./tmp-dscript-testing/")
|
||||
|
||||
def test_lm_embed(self):
|
||||
seqs = list(SeqIO.parse("dscript/tests/test.fasta", "fasta"))
|
||||
for seqrec in seqs:
|
||||
x = lm_embed(str(seqrec.seq))
|
||||
assert x.shape[1] == len(seqrec.seq)
|
||||
|
||||
def embed_from_fasta(self):
|
||||
embed_from_fasta(
|
||||
"dscript/tests/test.fasta",
|
||||
"tmp-dscript-testing/test_embed.h5",
|
||||
verbose=True,
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
import dscript
|
||||
from pathlib import Path
|
||||
|
||||
from dscript.pretrained import (
|
||||
@@ -14,8 +13,6 @@ MODEL_VERSIONS = [
|
||||
"lm_v1", # Bepler & Berger 2019
|
||||
]
|
||||
|
||||
print(dscript.__version__)
|
||||
|
||||
|
||||
def test_get_state_dict():
|
||||
for mv in MODEL_VERSIONS:
|
||||
|
||||
Reference in New Issue
Block a user