Files
D-SCRIPT/dscript/language_model.py

116 lines
3.6 KiB
Python

import logging as logg
import sys
import os
import subprocess as sp
import random
import torch
import h5py
from Bio import SeqIO
from tqdm import tqdm
from .alphabets import Uniprot21
from .pretrained import get_pretrained
from .models.embedding import SkipLSTM
from datetime import datetime
def lm_embed(sequence, use_cuda=False, verbose=True):
    """
    Compute the language-model embedding of a single protein sequence.

    Uses the pre-trained ``lm_v1`` model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
    The model is loaded afresh on every call.

    :param sequence: Sequence to embed
    :type sequence: str
    :param use_cuda: Run the model on the GPU instead of the CPU [default: False]
    :type use_cuda: bool
    :param verbose: Forwarded to ``get_pretrained`` when loading the model [default: True]
    :type verbose: bool
    :return: Output of ``model.transform`` for the sequence, moved to the CPU
    :rtype: torch.Tensor
    """
    lm = get_pretrained("lm_v1", verbose=verbose)
    # Overwrite the projection layer's parameters: weights are re-drawn in place
    # with torch.nn.init.normal_ (this consumes the global torch RNG) and the
    # bias becomes a fresh 100-element zero vector.
    # NOTE(review): presumably this does not affect the result because
    # ``transform`` may not route through ``proj`` -- confirm against SkipLSTM.
    torch.nn.init.normal_(lm.proj.weight)
    lm.proj.bias = torch.nn.Parameter(torch.zeros(100))
    lm = lm.cuda() if use_cuda else lm
    lm.eval()

    with torch.no_grad():
        codes = Uniprot21().encode(sequence.encode("utf-8"))
        # Add a leading batch dimension of size 1.
        tokens = torch.from_numpy(codes).long().unsqueeze(0)
        if use_cuda:
            tokens = tokens.cuda()
        embedding = lm.transform(tokens)
    return embedding.cpu()
def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False):
    """
    Embed every sequence in a FASTA file with the pre-trained language model from
    `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_
    and store the results in an HDF5 file.

    Each record is written as one LZF-compressed dataset named after the
    record's ``name``. If several records share a name, only the first one is
    stored. The output file is opened in ``"w"`` mode, so existing contents are
    discarded.

    :param fastaPath: Input sequence file (``.fasta`` format)
    :type fastaPath: str
    :param outputPath: Output embedding file (``.h5`` format)
    :type outputPath: str
    :param device: CUDA device index to use. A negative value, or no available
        CUDA device, selects the CPU [default: 0]
    :type device: int
    :param verbose: Log device, model-loading and progress messages
    :type verbose: bool
    """
    use_cuda = (device >= 0) and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        if verbose:
            logg.info(
                f"Using CUDA device {device} - {torch.cuda.get_device_name(device)}"
            )
    elif verbose:
        logg.info("Using CPU")

    if verbose:
        logg.info("Loading Model...")
    # Forward ``verbose`` the same way ``lm_embed`` does.
    model = get_pretrained("lm_v1", verbose=verbose)
    # Same projection-layer re-initialisation as in ``lm_embed``.
    torch.nn.init.normal_(model.proj.weight)
    model.proj.bias = torch.nn.Parameter(torch.zeros(100))
    if use_cuda:
        model = model.cuda()
    model.eval()

    if verbose:
        logg.info("Loading Sequences...")
    seq_records = list(SeqIO.parse(fastaPath, "fasta"))
    alphabet = Uniprot21()
    # Encode every sequence once, up front and on the CPU. This keeps
    # encoding failures ahead of creating the output file without staging the
    # whole data set in GPU memory. Each sequence moves to the device only when
    # it is embedded below. str() is used so the code does not depend on the
    # Seq object exposing an ``encode`` method.
    encoded = [
        alphabet.encode(str(rec.seq).encode("utf-8")) for rec in tqdm(seq_records)
    ]

    if verbose:
        # Count distinct names, since duplicate names are only written once.
        num_seqs = len({rec.name for rec in seq_records})
        logg.info("{} Sequences Loaded".format(num_seqs))
        logg.info(
            "Approximate Storage Required (varies by average sequence length): ~{}GB".format(
                num_seqs * (1 / 125)
            )
        )
        logg.info("Storing to {}...".format(outputPath))

    with h5py.File(outputPath, "w") as h5fi, torch.no_grad():
        try:
            for rec, enc in tqdm(zip(seq_records, encoded), total=len(seq_records)):
                if rec.name in h5fi:
                    # Keep the first record for a duplicated name.
                    continue
                # Add a leading batch dimension of size 1.
                x = torch.from_numpy(enc).long().unsqueeze(0)
                if use_cuda:
                    x = x.cuda()
                z = model.transform(x)
                h5fi.create_dataset(
                    rec.name, data=z.cpu().numpy(), compression="lzf"
                )
        except KeyboardInterrupt:
            # Exit with status 1. The enclosing ``with`` closes the HDF5 file
            # while SystemExit propagates, so no explicit close() is needed.
            sys.exit(1)