mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
import logging as logg
|
|
import os
|
|
import sys
|
|
from urllib.error import HTTPError
|
|
|
|
import torch
|
|
|
|
from .models.contact import ContactCNN
|
|
from .models.embedding import FullyConnectedEmbed, SkipLSTM
|
|
from .models.interaction import ModelInteraction
|
|
from .utils import get_local_or_download
|
|
|
|
|
|
def build_lm_1(state_dict_path):
    """
    Build the ``lm_v1`` language model and load its pre-trained weights.

    :param state_dict_path: Path to the saved ``state_dict`` for this model
    :type state_dict_path: str
    :return: ``SkipLSTM`` model with the weights loaded, in eval mode
    :rtype: dscript.models.embedding.SkipLSTM

    :meta private:
    """
    # Architecture hyper-parameters must match the ones the checkpoint was
    # saved with, otherwise load_state_dict raises on mismatched keys/shapes.
    model = SkipLSTM(21, 100, 1024, 3)

    # Deserialize onto the CPU: a checkpoint saved from GPU tensors would
    # otherwise fail to load on a machine without CUDA. load_state_dict copies
    # the values into the (CPU) parameters of the model built above, so the
    # resulting model is the same either way.
    # NOTE(review): torch.load unpickles the file; when the file comes from
    # get_state_dict it is downloaded over plain HTTP. Consider passing
    # weights_only=True where the installed torch version supports it.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Switch off training-only behaviour (e.g. dropout) for inference.
    model.eval()
    return model
|
|
|
|
|
|
def build_human_1(state_dict_path):
    """
    Build the ``human_v1`` interaction model and load its pre-trained weights.

    The model combines a ``FullyConnectedEmbed`` projection, a ``ContactCNN``
    and a ``ModelInteraction`` head (with weighting, pooling and a sigmoid
    output enabled).

    :param state_dict_path: Path to the saved ``state_dict`` for this model
    :type state_dict_path: str
    :return: Interaction model with the weights loaded, in eval mode
    :rtype: dscript.models.interaction.ModelInteraction

    :meta private:
    """
    # Hyper-parameters must match the ones the checkpoint was saved with,
    # otherwise load_state_dict raises on mismatched keys/shapes.
    embModel = FullyConnectedEmbed(6165, 100, 0.5)
    conModel = ContactCNN(100, 50, 7)
    model = ModelInteraction(
        embModel,
        conModel,
        # NOTE(review): hard-coded to True; what this flag changes inside
        # ModelInteraction is not visible here - confirm CPU-only behaviour.
        use_cuda=True,
        do_w=True,
        do_pool=True,
        do_sigmoid=True,
        pool_size=9,
    )

    # Deserialize onto the CPU: a checkpoint saved from GPU tensors would
    # otherwise fail to load on a machine without CUDA. load_state_dict copies
    # the values into the (CPU) parameters of the model built above, so the
    # resulting model is the same either way.
    # NOTE(review): torch.load unpickles the file; when the file comes from
    # get_state_dict it is downloaded over plain HTTP. Consider passing
    # weights_only=True where the installed torch version supports it.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Switch off training-only behaviour (e.g. dropout) for inference.
    model.eval()
    return model
|
|
|
|
|
|
# Registry of downloadable pre-trained models: each public version name maps
# to the builder that constructs the architecture and loads its weights from
# a state-dict path. get_pretrained() validates requested versions against it.
VALID_MODELS = dict(
    lm_v1=build_lm_1,
    human_v1=build_human_1,
)
|
|
|
|
|
|
def get_state_dict(version="human_v1", verbose=True):
    """
    Download a pre-trained model if it does not already exist on the local device.

    The weights are stored next to this module as ``dscript_<version>.pt``
    and fetched from the D-SCRIPT model server when needed.

    :param version: Version of trained model to download [default: human_v1]
    :type version: str
    :param verbose: Log model download status at INFO level [default: True]
    :type verbose: bool
    :return: Path to state dictionary for the pre-trained model
    :rtype: str
    :raises SystemExit: If the download fails with an HTTP error (the error
        is logged and the process exits with status 1)
    """
    state_dict_basename = f"dscript_{version}.pt"
    state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
    state_dict_fullname = os.path.join(state_dict_basedir, state_dict_basename)
    state_dict_url = (
        f"http://cb.csail.mit.edu/cb/dscript/data/models/{state_dict_basename}"
    )

    if verbose:
        # NOTE(review): logged even if a local copy already exists and
        # get_local_or_download presumably skips the fetch - confirm.
        logg.info("Downloading model %s from %s...", version, state_dict_url)

    try:
        get_local_or_download(state_dict_fullname, state_dict_url)
    except HTTPError as e:
        logg.error("Unable to download model - %s", e)
        # Library code terminating the process: callers cannot recover from
        # a failed download except by catching SystemExit.
        sys.exit(1)

    return state_dict_fullname
|
|
|
|
|
|
def get_pretrained(version="human_v1", verbose=True):
    """
    Get pre-trained model object.

    Currently Available Models
    ==========================

    See the `documentation <https://d-script.readthedocs.io/en/main/data.html#trained-models>`_ for most up-to-date list.

    - ``lm_v1`` - Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
    - ``human_v1`` - Human trained model from D-SCRIPT manuscript.

    Default: ``human_v1``

    :param version: Version of pre-trained model to get
    :type version: str
    :param verbose: Log model download status at INFO level [default: True]
    :type verbose: bool
    :return: Pre-trained model
    :rtype: dscript.models.*
    :raises ValueError: If ``version`` is not one of the available models
    """
    # Resolve the builder first so an unknown version fails fast, before any
    # download is attempted.
    try:
        builder = VALID_MODELS[version]
    except KeyError:
        raise ValueError("Model {} does not exist".format(version)) from None

    state_dict_path = get_state_dict(version, verbose=verbose)
    return builder(state_dict_path)
|