mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
Resolve merge
This commit is contained in:
40
.github/workflows/autorun-tests.yml
vendored
Normal file
40
.github/workflows/autorun-tests.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Python application
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python 3.7
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: "3.7"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install flake8 pytest
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
python setup.py install
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest
|
||||
39
.github/workflows/pypi_publish.yml
vendored
Normal file
39
.github/workflows/pypi_publish.yml
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
# This workflow will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Build package
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
@@ -16,16 +16,17 @@ repos:
|
||||
rev: 21.6b0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.8
|
||||
language_version: python3.7
|
||||
additional_dependencies: ['click==8.0.4']
|
||||
- repo: https://gitlab.com/pycqa/flake8
|
||||
rev: 3.9.2
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest-check
|
||||
name: pytest-check
|
||||
entry: pytest
|
||||
language: conda
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
#- repo: local
|
||||
# hooks:
|
||||
# - id: pytest-check
|
||||
# name: pytest-check
|
||||
# entry: pytest
|
||||
# language: conda
|
||||
# pass_filenames: false
|
||||
# always_run: true
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
|
||||
## v0.2
|
||||
|
||||
### v0.2.1
|
||||
- Add biopython to setup.py
|
||||
|
||||
### v0.2.0
|
||||
|
||||
- Integrate Topsy-Turvy to allow for top-down supervision
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.2.1-dev"
|
||||
__version__ = "0.2.2-dev"
|
||||
__citation__ = """Sledzieski, Singh, Cowen, Berger. "D-SCRIPT translates genome to phenome with sequence-based, structure-aware, genome-scale predictions of protein-protein interactions." Cell Systems 12, no. 10 (2021): 969-982.
|
||||
|
||||
Devkota, Singh, Sledzieski, Berger, Cowen, Topsy-Turvy: integrating a global view into sequence-based PPI prediction, Bioinformatics, In Press."""
|
||||
|
||||
@@ -4,6 +4,16 @@ D-SCRIPT: Structure Aware PPI Prediction
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
from .commands.embed import EmbeddingArguments
|
||||
from .commands.evaluate import EvaluateArguments
|
||||
from .commands.predict import PredictionArguments
|
||||
from .commands.train import TrainArguments
|
||||
|
||||
DScriptArguments = Union[
|
||||
EmbeddingArguments, EvaluateArguments, PredictionArguments, TrainArguments
|
||||
]
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
@@ -56,7 +66,7 @@ def main():
|
||||
module.add_args(sp)
|
||||
sp.set_defaults(cmd=name)
|
||||
|
||||
args = parser.parse_args()
|
||||
args: DScriptArguments = parser.parse_args()
|
||||
oc = OmegaConf.create(vars(args))
|
||||
modules[args.cmd].main(oc)
|
||||
|
||||
|
||||
@@ -2,12 +2,23 @@
|
||||
Generate new embeddings using pre-trained language model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import logging as logg
|
||||
import sys
|
||||
|
||||
from dscript.language_model import embed_from_fasta
|
||||
|
||||
from typing import Callable, NamedTuple
|
||||
|
||||
|
||||
class EmbeddingArguments(NamedTuple):
|
||||
cmd: str
|
||||
device: int
|
||||
outfile: str
|
||||
seqs: str
|
||||
func: Callable[[EmbeddingArguments], None]
|
||||
|
||||
|
||||
def add_args(parser):
|
||||
"""
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
Evaluate a trained model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
import logging as logg
|
||||
from typing import Callable, NamedTuple
|
||||
|
||||
import h5py
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -26,6 +28,15 @@ from ..utils import load_hdf5_parallel
|
||||
matplotlib.use("Agg")
|
||||
|
||||
|
||||
class EvaluateArguments(NamedTuple):
|
||||
cmd: str
|
||||
device: int
|
||||
model: str
|
||||
embedding: str
|
||||
test: str
|
||||
func: Callable[[EvaluateArguments], None]
|
||||
|
||||
|
||||
def add_args(parser):
|
||||
"""
|
||||
Create parser for command line utility.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Make new predictions with a pre-trained model. One of --seqs or --embeddings is required.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import logging as logg
|
||||
@@ -13,6 +14,8 @@ import pandas as pd
|
||||
import torch
|
||||
from scipy.special import comb
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
|
||||
from ..datamodules import CachedFasta, CachedH5
|
||||
from ..alphabets import Uniprot21
|
||||
@@ -21,6 +24,17 @@ from ..language_model import lm_embed
|
||||
from ..utils import load_hdf5_parallel
|
||||
|
||||
|
||||
class PredictionArguments(NamedTuple):
|
||||
cmd: str
|
||||
device: int
|
||||
embeddings: Optional[str]
|
||||
outfile: Optional[str]
|
||||
seqs: str
|
||||
model: str
|
||||
thresh: Optional[float]
|
||||
func: Callable[[PredictionArguments], None]
|
||||
|
||||
|
||||
def add_args(parser):
|
||||
"""
|
||||
Create parser for command line utility
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Train a new model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import gzip as gz
|
||||
@@ -14,12 +14,14 @@ import h5py
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytorch_lightning as pl
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optimizers
|
||||
from pytorch_lightning import loggers as pl_loggers
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
from ..datamodules import PPIDataModule
|
||||
|
||||
@@ -30,6 +32,36 @@ from ..models.lightning import LitInteraction
|
||||
from ..utils import config_logger
|
||||
|
||||
|
||||
class TrainArguments(NamedTuple):
|
||||
cmd: str
|
||||
device: int
|
||||
train: str
|
||||
test: str
|
||||
embedding: str
|
||||
no_augment: bool
|
||||
input_dim: int
|
||||
projection_dim: int
|
||||
dropout: float
|
||||
hidden_dim: int
|
||||
kernel_width: int
|
||||
no_w: bool
|
||||
no_sigmoid: bool
|
||||
do_pool: bool
|
||||
pool_width: int
|
||||
num_epochs: int
|
||||
batch_size: int
|
||||
weight_decay: float
|
||||
lr: float
|
||||
interaction_weight: float
|
||||
run_tt: bool
|
||||
glider_weight: float
|
||||
glider_thresh: float
|
||||
outfile: Optional[str]
|
||||
save_prefix: Optional[str]
|
||||
checkpoint: Optional[str]
|
||||
func: Callable[[TrainArguments], None]
|
||||
|
||||
|
||||
def add_args(parser):
|
||||
"""
|
||||
Create parser for command line utility.
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
from __future__ import print_function, division
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Alphabet:
|
||||
"""
|
||||
From `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
|
||||
|
||||
:param chars: List of characters in alphabet
|
||||
:type chars: byte str
|
||||
:param encoding: Mapping of characters to numbers [default: encoding]
|
||||
:type encoding: np.ndarray
|
||||
:param mask: Set encoding mask [default: False]
|
||||
:type mask: bool
|
||||
:param missing: Number to use for a value outside the alphabet [default: 255]
|
||||
:type missing: int
|
||||
"""
|
||||
|
||||
def __init__(self, chars, encoding=None, mask=False, missing=255):
|
||||
self.chars = np.frombuffer(chars, dtype=np.uint8)
|
||||
self.encoding = np.zeros(256, dtype=np.uint8) + missing
|
||||
if encoding is None:
|
||||
self.encoding[self.chars] = np.arange(len(self.chars))
|
||||
self.size = len(self.chars)
|
||||
else:
|
||||
self.encoding[self.chars] = encoding
|
||||
self.size = encoding.max() + 1
|
||||
self.mask = mask
|
||||
if mask:
|
||||
self.size -= 1
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
def __getitem__(self, i):
|
||||
return chr(self.chars[i])
|
||||
|
||||
def encode(self, x):
|
||||
"""
|
||||
Encode a byte string into alphabet indices
|
||||
|
||||
:param x: Amino acid string
|
||||
:type x: byte str
|
||||
:return: Numeric encoding
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
x = np.frombuffer(x, dtype=np.uint8)
|
||||
return self.encoding[x]
|
||||
|
||||
def decode(self, x):
|
||||
"""
|
||||
Decode numeric encoding to byte string of this alphabet
|
||||
|
||||
:param x: Numeric encoding
|
||||
:type x: np.ndarray
|
||||
:return: Amino acid string
|
||||
:rtype: byte str
|
||||
"""
|
||||
string = self.chars[x]
|
||||
return string.tobytes()
|
||||
|
||||
|
||||
class Uniprot21(Alphabet):
|
||||
"""
|
||||
Uniprot 21 Amino Acid Encoding.
|
||||
|
||||
From `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, mask=False):
|
||||
chars = b"ARNDCQEGHILKMFPSTWYVXOUBZ"
|
||||
encoding = np.arange(len(chars))
|
||||
encoding[21:] = [11, 4, 20, 20] # encode 'OUBZ' as synonyms
|
||||
super(Uniprot21, self).__init__(
|
||||
chars, encoding=encoding, mask=mask, missing=20
|
||||
)
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
Contact model classes.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.functional as F
|
||||
|
||||
|
||||
class FullyConnected(nn.Module):
|
||||
"""
|
||||
Performs part 1 of Contact Prediction Module. Takes embeddings from Projection module and produces broadcast tensor.
|
||||
|
||||
Input embeddings of dimension :math:`d` are combined into a :math:`2d` length MLP input :math:`z_{cat}`, where :math:`z_{cat} = [z_0 \\ominus z_1 | z_0 \\odot z_1]`
|
||||
|
||||
:param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
|
||||
:type embed_dim: int
|
||||
:param hidden_dim: Hidden dimension :math:`h` [default: 50]
|
||||
:type hidden_dim: int
|
||||
:param activation: Activation function for broadcast tensor [default: torch.nn.ReLU()]
|
||||
:type activation: torch.nn.Module
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim, hidden_dim, activation=nn.ReLU()):
|
||||
super(FullyConnected, self).__init__()
|
||||
|
||||
self.D = embed_dim
|
||||
self.H = hidden_dim
|
||||
self.conv = nn.Conv2d(2 * self.D, self.H, 1)
|
||||
self.batchnorm = nn.BatchNorm2d(self.H)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, z0, z1):
|
||||
"""
|
||||
:param z0: Projection module embedding :math:`(b \\times N \\times d)`
|
||||
:type z0: torch.Tensor
|
||||
:param z1: Projection module embedding :math:`(b \\times M \\times d)`
|
||||
:type z1: torch.Tensor
|
||||
:return: Predicted broadcast tensor :math:`(b \\times N \\times M \\times h)`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
|
||||
# z0 is (b,N,d), z1 is (b,M,d)
|
||||
z0 = z0.transpose(1, 2)
|
||||
z1 = z1.transpose(1, 2)
|
||||
# z0 is (b,d,N), z1 is (b,d,M)
|
||||
|
||||
z_dif = torch.abs(z0.unsqueeze(3) - z1.unsqueeze(2))
|
||||
z_mul = z0.unsqueeze(3) * z1.unsqueeze(2)
|
||||
z_cat = torch.cat([z_dif, z_mul], 1)
|
||||
|
||||
b = self.conv(z_cat)
|
||||
b = self.activation(b)
|
||||
b = self.batchnorm(b)
|
||||
|
||||
return b
|
||||
|
||||
|
||||
class ContactCNN(nn.Module):
|
||||
"""
|
||||
Residue Contact Prediction Module. Takes embeddings from Projection module and produces contact map, output of Contact module.
|
||||
|
||||
:param embed_dim: Output dimension of `dscript.models.embedding <#module-dscript.models.embedding>`_ model :math:`d` [default: 100]
|
||||
:type embed_dim: int
|
||||
:param hidden_dim: Hidden dimension :math:`h` [default: 50]
|
||||
:type hidden_dim: int
|
||||
:param width: Width of convolutional filter :math:`2w+1` [default: 7]
|
||||
:type width: int
|
||||
:param activation: Activation function for final contact map [default: torch.nn.Sigmoid()]
|
||||
:type activation: torch.nn.Module
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embed_dim=100, hidden_dim=50, width=7, activation=nn.Sigmoid()
|
||||
):
|
||||
super(ContactCNN, self).__init__()
|
||||
|
||||
self.hidden = FullyConnected(embed_dim, hidden_dim)
|
||||
self.conv = nn.Conv2d(hidden_dim, 1, width, padding=width // 2)
|
||||
self.batchnorm = nn.BatchNorm2d(1)
|
||||
self.activation = activation
|
||||
self.clip()
|
||||
|
||||
def clip(self):
|
||||
"""
|
||||
Force the convolutional layer to be transpose invariant.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
|
||||
w = self.conv.weight
|
||||
self.conv.weight.data[:] = 0.5 * (w + w.transpose(2, 3))
|
||||
|
||||
def forward(self, z0, z1):
|
||||
"""
|
||||
:param z0: Projection module embedding :math:`(b \\times N \\times d)`
|
||||
:type z0: torch.Tensor
|
||||
:param z1: Projection module embedding :math:`(b \\times M \\times d)`
|
||||
:type z1: torch.Tensor
|
||||
:return: Predicted contact map :math:`(b \\times N \\times M)`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
B = self.broadcast(z0, z1)
|
||||
return self.predict(B)
|
||||
|
||||
def broadcast(self, z0, z1):
|
||||
"""
|
||||
Calls `dscript.models.contact.FullyConnected <#module-dscript.models.contact.FullyConnected>`_.
|
||||
|
||||
:param z0: Projection module embedding :math:`(b \\times N \\times d)`
|
||||
:type z0: torch.Tensor
|
||||
:param z1: Projection module embedding :math:`(b \\times M \\times d)`
|
||||
:type z1: torch.Tensor
|
||||
:return: Predicted contact broadcast tensor :math:`(b \\times N \\times M \\times h)`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
B = self.hidden(z0, z1)
|
||||
return B
|
||||
|
||||
def predict(self, B):
|
||||
"""
|
||||
Predict contact map from broadcast tensor.
|
||||
|
||||
:param B: Predicted contact broadcast :math:`(b \\times N \\times M \\times h)`
|
||||
:type B: torch.Tensor
|
||||
:return: Predicted contact map :math:`(b \\times N \\times M)`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
C = self.conv(B)
|
||||
C = self.batchnorm(C)
|
||||
C = self.activation(C)
|
||||
return C
|
||||
@@ -1,185 +0,0 @@
|
||||
"""
|
||||
Embedding model classes.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
|
||||
|
||||
class IdentityEmbed(nn.Module):
|
||||
"""
|
||||
Does not reduce the dimension of the language model embeddings, just passes them through to the contact model.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x: Input language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type x: torch.Tensor
|
||||
:return: Same embedding
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
class FullyConnectedEmbed(nn.Module):
|
||||
"""
|
||||
Protein Projection Module. Takes embedding from language model and outputs low-dimensional interaction aware projection.
|
||||
|
||||
:param nin: Size of language model output
|
||||
:type nin: int
|
||||
:param nout: Dimension of projection
|
||||
:type nout: int
|
||||
:param dropout: Proportion of weights to drop out [default: 0.5]
|
||||
:type dropout: float
|
||||
:param activation: Activation for linear projection model
|
||||
:type activation: torch.nn.Module
|
||||
"""
|
||||
|
||||
def __init__(self, nin, nout, dropout=0.5, activation=nn.ReLU()):
|
||||
super(FullyConnectedEmbed, self).__init__()
|
||||
self.nin = nin
|
||||
self.nout = nout
|
||||
self.dropout_p = dropout
|
||||
|
||||
self.transform = nn.Linear(nin, nout)
|
||||
self.drop = nn.Dropout(p=self.dropout_p)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:param x: Input language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type x: torch.Tensor
|
||||
:return: Low dimensional projection of embedding
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
t = self.transform(x)
|
||||
t = self.activation(t)
|
||||
t = self.drop(t)
|
||||
return t
|
||||
|
||||
|
||||
class SkipLSTM(nn.Module):
|
||||
"""
|
||||
Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
|
||||
|
||||
Loaded with pre-trained weights in embedding function.
|
||||
|
||||
:param nin: Input dimension of amino acid one-hot [default: 21]
|
||||
:type nin: int
|
||||
:param nout: Output dimension of final layer [default: 100]
|
||||
:type nout: int
|
||||
:param hidden_dim: Size of hidden dimension [default: 1024]
|
||||
:type hidden_dim: int
|
||||
:param num_layers: Number of stacked LSTM models [default: 3]
|
||||
:type num_layers: int
|
||||
:param dropout: Proportion of weights to drop out [default: 0]
|
||||
:type dropout: float
|
||||
:param bidirectional: Whether to use biLSTM vs. LSTM
|
||||
:type bidirectional: bool
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nin=21,
|
||||
nout=100,
|
||||
hidden_dim=1024,
|
||||
num_layers=3,
|
||||
dropout=0,
|
||||
bidirectional=True,
|
||||
):
|
||||
super(SkipLSTM, self).__init__()
|
||||
|
||||
self.nin = nin
|
||||
self.nout = nout
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
dim = nin
|
||||
for i in range(num_layers):
|
||||
f = nn.LSTM(
|
||||
dim,
|
||||
hidden_dim,
|
||||
1,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
self.layers.append(f)
|
||||
if bidirectional:
|
||||
dim = 2 * hidden_dim
|
||||
else:
|
||||
dim = hidden_dim
|
||||
|
||||
n = hidden_dim * num_layers + nin
|
||||
if bidirectional:
|
||||
n = 2 * hidden_dim * num_layers + nin
|
||||
|
||||
self.proj = nn.Linear(n, nout)
|
||||
|
||||
def to_one_hot(self, x):
|
||||
"""
|
||||
Transform numeric encoded amino acid vector to one-hot encoded vector
|
||||
|
||||
:param x: Input numeric amino acid encoding :math:`(N)`
|
||||
:type x: torch.Tensor
|
||||
:return: One-hot encoding vector :math:`(N \\times n_{in})`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
packed = type(x) is PackedSequence
|
||||
if packed:
|
||||
one_hot = x.data.new(x.data.size(0), self.nin).float().zero_()
|
||||
one_hot.scatter_(1, x.data.unsqueeze(1), 1)
|
||||
one_hot = PackedSequence(one_hot, x.batch_sizes)
|
||||
else:
|
||||
one_hot = x.new(x.size(0), x.size(1), self.nin).float().zero_()
|
||||
one_hot.scatter_(2, x.unsqueeze(2), 1)
|
||||
return one_hot
|
||||
|
||||
def transform(self, x):
|
||||
"""
|
||||
:param x: Input numeric amino acid encoding :math:`(N)`
|
||||
:type x: torch.Tensor
|
||||
:return: Concatenation of all hidden layers :math:`(N \\times (n_{in} + 2 \\times \\text{num_layers} \\times \\text{hidden_dim}))`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
one_hot = self.to_one_hot(x)
|
||||
hs = [one_hot] # []
|
||||
h_ = one_hot
|
||||
for f in self.layers:
|
||||
h, _ = f(h_)
|
||||
# h = self.dropout(h)
|
||||
hs.append(h)
|
||||
h_ = h
|
||||
if type(x) is PackedSequence:
|
||||
h = torch.cat([z.data for z in hs], 1)
|
||||
h = PackedSequence(h, x.batch_sizes)
|
||||
else:
|
||||
h = torch.cat([z for z in hs], 2)
|
||||
return h
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
:meta private:
|
||||
"""
|
||||
one_hot = self.to_one_hot(x)
|
||||
hs = [one_hot]
|
||||
h_ = one_hot
|
||||
|
||||
for f in self.layers:
|
||||
h, _ = f(h_)
|
||||
# h = self.dropout(h)
|
||||
hs.append(h)
|
||||
h_ = h
|
||||
|
||||
if type(x) is PackedSequence:
|
||||
h = torch.cat([z.data for z in hs], 1)
|
||||
z = self.proj(h)
|
||||
z = PackedSequence(z, x.batch_sizes)
|
||||
else:
|
||||
h = torch.cat([z for z in hs], 2)
|
||||
z = self.proj(h.view(-1, h.size(2)))
|
||||
z = z.view(x.size(0), x.size(1), -1)
|
||||
|
||||
return z
|
||||
@@ -1,78 +0,0 @@
|
||||
def parse(f, comment="#"):
|
||||
"""
|
||||
Parse a file in ``.fasta`` format.
|
||||
|
||||
:param f: Input file object
|
||||
:type f: _io.TextIOWrapper
|
||||
:param comment: Character used for comments
|
||||
:type comment: str
|
||||
|
||||
:return: names, sequence
|
||||
:rtype: list[str], list[str]
|
||||
"""
|
||||
starter = ">"
|
||||
empty = ""
|
||||
if "b" in f.mode:
|
||||
comment = b"#"
|
||||
starter = b">"
|
||||
empty = b""
|
||||
names = []
|
||||
sequences = []
|
||||
name = None
|
||||
sequence = []
|
||||
for line in f:
|
||||
if line.startswith(comment):
|
||||
continue
|
||||
line = line.strip()
|
||||
if line.startswith(starter):
|
||||
if name is not None:
|
||||
names.append(name)
|
||||
sequences.append(empty.join(sequence))
|
||||
name = line[1:]
|
||||
sequence = []
|
||||
else:
|
||||
sequence.append(line.upper())
|
||||
if name is not None:
|
||||
names.append(name)
|
||||
sequences.append(empty.join(sequence))
|
||||
|
||||
return names, sequences
|
||||
|
||||
|
||||
def parse_directory(directory, extension=".seq"):
|
||||
"""
|
||||
Parse all files in a directory ending with ``extension``.
|
||||
|
||||
:param directory: Input directory
|
||||
:type directory: str
|
||||
:param extension: Extension of all files to read in
|
||||
:type extension: str
|
||||
|
||||
:return: names, sequence
|
||||
:rtype: list[str], list[str]
|
||||
"""
|
||||
names = []
|
||||
sequences = []
|
||||
|
||||
for seqPath in os.listdir(directory):
|
||||
if seqPath.endswith(extension):
|
||||
n, s = parse(open(f"{directory}/{seqPath}", "rb"))
|
||||
names.append(n[0].decode("utf-8").strip())
|
||||
sequences.append(s[0].decode("utf-8").strip())
|
||||
return names, sequences
|
||||
|
||||
|
||||
def write(nam, seq, f):
|
||||
"""
|
||||
Write a file in ``.fasta`` format.
|
||||
|
||||
:param nam: List of names
|
||||
:type nam: list[str]
|
||||
:param seq: List of sequences
|
||||
:type seq: list[str]
|
||||
:param f: Output file object
|
||||
:type f: _io.TextIOWrapper
|
||||
"""
|
||||
for n, s in zip(nam, seq):
|
||||
f.write(">{}\n".format(n))
|
||||
f.write("{}\n".format(s))
|
||||
@@ -1,221 +0,0 @@
|
||||
"""
|
||||
Interaction model classes.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.functional as F
|
||||
|
||||
|
||||
class LogisticActivation(nn.Module):
|
||||
"""
|
||||
Implementation of Generalized Sigmoid
|
||||
Applies the element-wise function:
|
||||
|
||||
:math:`\\sigma(x) = \\frac{1}{1 + \\exp(-k(x-x_0))}`
|
||||
|
||||
:param x0: The value of the sigmoid midpoint
|
||||
:type x0: float
|
||||
:param k: The slope of the sigmoid - trainable - :math:`k \\geq 0`
|
||||
:type k: float
|
||||
:param train: Whether :math:`k` is a trainable parameter
|
||||
:type train: bool
|
||||
"""
|
||||
|
||||
def __init__(self, x0=0, k=1, train=False):
|
||||
super(LogisticActivation, self).__init__()
|
||||
self.x0 = x0
|
||||
self.k = nn.Parameter(torch.FloatTensor([float(k)]))
|
||||
self.k.requiresGrad = train
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Applies the function to the input elementwise
|
||||
|
||||
:param x: :math:`(N \\times *)` where :math:`*` means, any number of additional dimensions
|
||||
:type x: torch.Tensor
|
||||
:return: :math:`(N \\times *)`, same shape as the input
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
out = torch.clamp(
|
||||
1 / (1 + torch.exp(-self.k * (x - self.x0))), min=0, max=1
|
||||
).squeeze()
|
||||
return out
|
||||
|
||||
def clip(self):
|
||||
"""
|
||||
Restricts sigmoid slope :math:`k` to be greater than or equal to 0, if :math:`k` is trained.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
self.k.data.clamp_(min=0)
|
||||
|
||||
|
||||
class ModelInteraction(nn.Module):
|
||||
"""
|
||||
Main D-SCRIPT model. Contains an embedding and contact model and offers access to those models. Computes pooling operations on contact map to generate interaction probability.
|
||||
|
||||
:param embedding: Embedding model
|
||||
:type embedding: dscript.models.embedding.FullyConnectedEmbed
|
||||
:param contact: Contact model
|
||||
:type contact: dscript.models.contact.ContactCNN
|
||||
:param use_cuda: Whether the model should be run on GPU
|
||||
:type use_cuda: bool
|
||||
:param pool_size: width of max-pool [default 9]
|
||||
:type pool_size: bool
|
||||
:param theta_init: initialization value of :math:`\\theta` for weight matrix [default: 1]
|
||||
:type theta_init: float
|
||||
:param lambda_init: initialization value of :math:`\\lambda` for weight matrix [default: 0]
|
||||
:type lambda_init: float
|
||||
:param gamma_init: initialization value of :math:`\\gamma` for global pooling [default: 0]
|
||||
:type gamma_init: float
|
||||
:param use_W: whether to use the weighting matrix [default: True]
|
||||
:type use_W: bool
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding,
|
||||
contact,
|
||||
pool_size=9,
|
||||
theta_init=1,
|
||||
lambda_init=0,
|
||||
gamma_init=0,
|
||||
use_W=True,
|
||||
):
|
||||
super(ModelInteraction, self).__init__()
|
||||
self.use_W = use_W
|
||||
self.activation = LogisticActivation(x0=0.5, k=20)
|
||||
|
||||
self.embedding = embedding
|
||||
self.contact = contact
|
||||
|
||||
if self.use_W:
|
||||
self.theta = nn.Parameter(torch.FloatTensor([theta_init]))
|
||||
self.lambda_ = nn.Parameter(torch.FloatTensor([lambda_init]))
|
||||
|
||||
self.maxPool = nn.MaxPool2d(pool_size, padding=pool_size // 2)
|
||||
self.gamma = nn.Parameter(torch.FloatTensor([gamma_init]))
|
||||
|
||||
self.clip()
|
||||
|
||||
def clip(self):
|
||||
"""
|
||||
Clamp model values
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
self.contact.clip()
|
||||
|
||||
if self.use_W:
|
||||
self.theta.data.clamp_(min=0, max=1)
|
||||
self.lambda_.data.clamp_(min=0)
|
||||
|
||||
self.gamma.data.clamp_(min=0)
|
||||
|
||||
def embed(self, z):
|
||||
"""
|
||||
Project down input language model embeddings into low dimension using projection module
|
||||
|
||||
:param z: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z: torch.Tensor
|
||||
:return: D-SCRIPT projection :math:`(b \\times N \\times d)`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if self.embedding is None:
|
||||
return z
|
||||
else:
|
||||
return self.embedding(z)
|
||||
|
||||
def cpred(self, z0, z1):
|
||||
"""
|
||||
Project down input language model embeddings into low dimension using projection module
|
||||
|
||||
:param z0: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z0: torch.Tensor
|
||||
:param z1: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z1: torch.Tensor
|
||||
:return: Predicted contact map :math:`(b \\times N \\times M)`
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
e0 = self.embed(z0)
|
||||
e1 = self.embed(z1)
|
||||
B = self.contact.broadcast(e0, e1)
|
||||
C = self.contact.predict(B)
|
||||
return C
|
||||
|
||||
def map_predict(self, z0, z1):
|
||||
"""
|
||||
Project down input language model embeddings into low dimension using projection module
|
||||
|
||||
:param z0: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z0: torch.Tensor
|
||||
:param z1: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z1: torch.Tensor
|
||||
:return: Predicted contact map, predicted probability of interaction :math:`(b \\times N \\times d_0), (1)`
|
||||
:rtype: torch.Tensor, torch.Tensor
|
||||
"""
|
||||
|
||||
C = self.cpred(z0, z1)
|
||||
|
||||
if self.use_W:
|
||||
# Create contact weighting matrix
|
||||
N, M = C.shape[2:]
|
||||
|
||||
x1 = torch.from_numpy(
|
||||
-1
|
||||
* ((np.arange(N) + 1 - ((N + 1) / 2)) / (-1 * ((N + 1) / 2)))
|
||||
** 2
|
||||
).float()
|
||||
if self.gamma.device.type == "cuda":
|
||||
x1 = x1.cuda()
|
||||
x1 = torch.exp(self.lambda_ * x1)
|
||||
|
||||
x2 = torch.from_numpy(
|
||||
-1
|
||||
* ((np.arange(M) + 1 - ((M + 1) / 2)) / (-1 * ((M + 1) / 2)))
|
||||
** 2
|
||||
).float()
|
||||
if self.gamma.device.type == "cuda":
|
||||
x2 = x2.cuda()
|
||||
x2 = torch.exp(self.lambda_ * x2)
|
||||
|
||||
W = x1.unsqueeze(1) * x2
|
||||
W = (1 - self.theta) * W + self.theta
|
||||
|
||||
yhat = C * W
|
||||
|
||||
else:
|
||||
yhat = C
|
||||
|
||||
yhat = self.maxPool(yhat)
|
||||
|
||||
# Mean of contact predictions where p_ij > mu + gamma*sigma
|
||||
mu = torch.mean(yhat)
|
||||
sigma = torch.var(yhat)
|
||||
Q = torch.relu(yhat - mu - (self.gamma * sigma))
|
||||
phat = torch.sum(Q) / (torch.sum(torch.sign(Q)) + 1)
|
||||
phat = self.activation(phat)
|
||||
return C, phat
|
||||
|
||||
def predict(self, z0, z1):
|
||||
"""
|
||||
Project down input language model embeddings into low dimension using projection module
|
||||
|
||||
:param z0: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z0: torch.Tensor
|
||||
:param z1: Language model embedding :math:`(b \\times N \\times d_0)`
|
||||
:type z1: torch.Tensor
|
||||
:return: Predicted probability of interaction
|
||||
:rtype: torch.Tensor, torch.Tensor
|
||||
"""
|
||||
_, phat = self.map_predict(z0, z1)
|
||||
return phat
|
||||
|
||||
def forward(self, z0, z1):
|
||||
"""
|
||||
:meta private:
|
||||
"""
|
||||
return self.predict(z0, z1)
|
||||
@@ -1,616 +0,0 @@
|
||||
"""
|
||||
Train a new model.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import h5py
|
||||
import datetime
|
||||
import subprocess as sp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import gzip as gz
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
from sklearn.metrics import average_precision_score as average_precision
|
||||
|
||||
import dscript
|
||||
from dscript.utils import PairedDataset, collate_paired_sequences
|
||||
from dscript.models.embedding import (
|
||||
IdentityEmbed,
|
||||
FullyConnectedEmbed,
|
||||
)
|
||||
from dscript.models.contact import ContactCNN
|
||||
from dscript.models.interaction import ModelInteraction
|
||||
|
||||
|
||||
def add_args(parser):
|
||||
"""
|
||||
Create parser for command line utility.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
|
||||
data_grp = parser.add_argument_group("Data")
|
||||
proj_grp = parser.add_argument_group("Projection Module")
|
||||
contact_grp = parser.add_argument_group("Contact Module")
|
||||
inter_grp = parser.add_argument_group("Interaction Module")
|
||||
train_grp = parser.add_argument_group("Training")
|
||||
misc_grp = parser.add_argument_group("Output and Device")
|
||||
|
||||
# Data
|
||||
data_grp.add_argument("--train", help="Training data", required=True)
|
||||
data_grp.add_argument("--val", help="Validation data", required=True)
|
||||
data_grp.add_argument(
|
||||
"--embedding", help="h5 file with embedded sequences", required=True
|
||||
)
|
||||
data_grp.add_argument(
|
||||
"--no-augment",
|
||||
action="store_false",
|
||||
dest="augment",
|
||||
help="Set flag to not augment data by adding (B A) for all pairs (A B)",
|
||||
)
|
||||
|
||||
# Embedding model
|
||||
proj_grp.add_argument(
|
||||
"--projection-dim",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Dimension of embedding projection layer (default: 100)",
|
||||
)
|
||||
proj_grp.add_argument(
|
||||
"--dropout-p",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Parameter p for embedding dropout layer (default: 0.5)",
|
||||
)
|
||||
|
||||
# Contact model
|
||||
contact_grp.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of hidden units for comparison layer in contact prediction (default: 50)",
|
||||
)
|
||||
contact_grp.add_argument(
|
||||
"--kernel-width",
|
||||
type=int,
|
||||
default=7,
|
||||
help="Width of convolutional filter for contact prediction (default: 7)",
|
||||
)
|
||||
|
||||
# Interaction Model
|
||||
inter_grp.add_argument(
|
||||
"--no-w",
|
||||
action="store_false",
|
||||
dest="use_w",
|
||||
help="Don't use weight matrix in interaction prediction model",
|
||||
)
|
||||
inter_grp.add_argument(
|
||||
"--pool-width",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Size of max-pool in interaction model (default: 9)",
|
||||
)
|
||||
|
||||
# Training
|
||||
train_grp.add_argument(
|
||||
"--negative-ratio",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of negative training samples for each positive training sample (default: 10)",
|
||||
)
|
||||
train_grp.add_argument(
|
||||
"--epoch-scale",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Report heldout performance every this many epochs (default: 1)",
|
||||
)
|
||||
train_grp.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of epochs (default: 10)",
|
||||
)
|
||||
train_grp.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="Minibatch size (default: 25)",
|
||||
)
|
||||
train_grp.add_argument(
|
||||
"--weight-decay",
|
||||
type=float,
|
||||
default=0,
|
||||
help="L2 regularization (default: 0)",
|
||||
)
|
||||
train_grp.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="Learning rate (default: 0.001)",
|
||||
)
|
||||
train_grp.add_argument(
|
||||
"--lambda",
|
||||
dest="lambda_",
|
||||
type=float,
|
||||
default=0.35,
|
||||
help="Weight on the similarity objective (default: 0.35)",
|
||||
)
|
||||
|
||||
# Output
|
||||
misc_grp.add_argument(
|
||||
"-o", "--outfile", help="Output file path (default: stdout)"
|
||||
)
|
||||
misc_grp.add_argument(
|
||||
"--save-prefix", help="Path prefix for saving models"
|
||||
)
|
||||
misc_grp.add_argument(
|
||||
"-d", "--device", type=int, default=-1, help="Compute device to use"
|
||||
)
|
||||
misc_grp.add_argument(
|
||||
"--checkpoint", help="Checkpoint model to start training from"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def predict_interaction(model, n0, n1, tensors, use_cuda):
|
||||
"""
|
||||
Predict whether a list of protein pairs will interact.
|
||||
|
||||
:param model: Model to be trained
|
||||
:type model: dscript.models.interaction.ModelInteraction
|
||||
:param n0: First protein names
|
||||
:type n0: list[str]
|
||||
:param n1: Second protein names
|
||||
:type n1: list[str]
|
||||
:param tensors: Dictionary of protein names to embeddings
|
||||
:type tensors: dict[str, torch.Tensor]
|
||||
:param use_cuda: Whether to use GPU
|
||||
:type use_cuda: bool
|
||||
"""
|
||||
|
||||
b = len(n0)
|
||||
|
||||
p_hat = []
|
||||
for i in range(b):
|
||||
z_a = tensors[n0[i]]
|
||||
z_b = tensors[n1[i]]
|
||||
if use_cuda:
|
||||
z_a = z_a.cuda()
|
||||
z_b = z_b.cuda()
|
||||
|
||||
p_hat.append(model.predict(z_a, z_b))
|
||||
p_hat = torch.stack(p_hat, 0)
|
||||
return p_hat
|
||||
|
||||
|
||||
def predict_cmap_interaction(model, n0, n1, tensors, use_cuda):
|
||||
"""
|
||||
Predict whether a list of protein pairs will interact, as well as their contact map.
|
||||
|
||||
:param model: Model to be trained
|
||||
:type model: dscript.models.interaction.ModelInteraction
|
||||
:param n0: First protein names
|
||||
:type n0: list[str]
|
||||
:param n1: Second protein names
|
||||
:type n1: list[str]
|
||||
:param tensors: Dictionary of protein names to embeddings
|
||||
:type tensors: dict[str, torch.Tensor]
|
||||
:param use_cuda: Whether to use GPU
|
||||
:type use_cuda: bool
|
||||
"""
|
||||
|
||||
b = len(n0)
|
||||
|
||||
p_hat = []
|
||||
c_map_mag = []
|
||||
for i in range(b):
|
||||
z_a = tensors[n0[i]]
|
||||
z_b = tensors[n1[i]]
|
||||
if use_cuda:
|
||||
z_a = z_a.cuda()
|
||||
z_b = z_b.cuda()
|
||||
|
||||
cm, ph = model.map_predict(z_a, z_b)
|
||||
p_hat.append(ph)
|
||||
c_map_mag.append(torch.mean(cm))
|
||||
p_hat = torch.stack(p_hat, 0)
|
||||
c_map_mag = torch.stack(c_map_mag, 0)
|
||||
return c_map_mag, p_hat
|
||||
|
||||
|
||||
def interaction_grad(model, n0, n1, y, tensors, use_cuda, weight=0.35):
|
||||
"""
|
||||
Compute gradient and backpropagate loss for a batch.
|
||||
|
||||
:param model: Model to be trained
|
||||
:type model: dscript.models.interaction.ModelInteraction
|
||||
:param n0: First protein names
|
||||
:type n0: list[str]
|
||||
:param n1: Second protein names
|
||||
:type n1: list[str]
|
||||
:param y: Interaction labels
|
||||
:type y: torch.Tensor
|
||||
:param tensors: Dictionary of protein names to embeddings
|
||||
:type tensors: dict[str, torch.Tensor]
|
||||
:param use_cuda: Whether to use GPU
|
||||
:type use_cuda: bool
|
||||
:param weight: Weight on the contact map magnitude objective. BCE loss is :math:`1 - \\text{weight}`.
|
||||
:type weight: float
|
||||
|
||||
:return: (Loss, number correct, mean square error, batch size)
|
||||
:rtype: (torch.Tensor, int, torch.Tensor, int)
|
||||
"""
|
||||
|
||||
c_map_mag, p_hat = predict_cmap_interaction(
|
||||
model, n0, n1, tensors, use_cuda
|
||||
)
|
||||
if use_cuda:
|
||||
y = y.cuda()
|
||||
y = Variable(y)
|
||||
|
||||
bce_loss = F.binary_cross_entropy(p_hat.float(), y.float())
|
||||
cmap_loss = torch.mean(c_map_mag)
|
||||
loss = (weight * bce_loss) + ((1 - weight) * cmap_loss)
|
||||
b = len(p_hat)
|
||||
|
||||
# backprop loss
|
||||
loss.backward()
|
||||
|
||||
if use_cuda:
|
||||
y = y.cpu()
|
||||
p_hat = p_hat.cpu()
|
||||
|
||||
with torch.no_grad():
|
||||
guess_cutoff = 0.5
|
||||
p_hat = p_hat.float()
|
||||
p_guess = (guess_cutoff * torch.ones(b) < p_hat).float()
|
||||
y = y.float()
|
||||
correct = torch.sum(p_guess == y).item()
|
||||
mse = torch.mean((y.float() - p_hat) ** 2).item()
|
||||
|
||||
return loss, correct, mse, b
|
||||
|
||||
|
||||
def interaction_eval(model, test_iterator, tensors, use_cuda):
|
||||
"""
|
||||
Evaluate test data set performance.
|
||||
|
||||
:param model: Model to be trained
|
||||
:type model: dscript.models.interaction.ModelInteraction
|
||||
:param test_iterator: Test data iterator
|
||||
:type test_iterator: torch.utils.data.DataLoader
|
||||
:param tensors: Dictionary of protein names to embeddings
|
||||
:type tensors: dict[str, torch.Tensor]
|
||||
:param use_cuda: Whether to use GPU
|
||||
:type use_cuda: bool
|
||||
|
||||
:return: (Loss, number correct, mean square error, precision, recall, F1 Score, AUPR)
|
||||
:rtype: (torch.Tensor, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor)
|
||||
"""
|
||||
p_hat = []
|
||||
true_y = []
|
||||
|
||||
for n0, n1, y in test_iterator:
|
||||
p_hat.append(predict_interaction(model, n0, n1, tensors, use_cuda))
|
||||
true_y.append(y)
|
||||
|
||||
y = torch.cat(true_y, 0)
|
||||
p_hat = torch.cat(p_hat, 0)
|
||||
|
||||
if use_cuda:
|
||||
y.cuda()
|
||||
p_hat = torch.Tensor([x.cuda() for x in p_hat])
|
||||
p_hat.cuda()
|
||||
|
||||
loss = F.binary_cross_entropy(p_hat.float(), y.float()).item()
|
||||
b = len(y)
|
||||
|
||||
with torch.no_grad():
|
||||
guess_cutoff = torch.Tensor([0.5]).float()
|
||||
p_hat = p_hat.float()
|
||||
y = y.float()
|
||||
p_guess = (guess_cutoff * torch.ones(b) < p_hat).float()
|
||||
correct = torch.sum(p_guess == y).item()
|
||||
mse = torch.mean((y.float() - p_hat) ** 2).item()
|
||||
|
||||
tp = torch.sum(y * p_hat).item()
|
||||
pr = tp / torch.sum(p_hat).item()
|
||||
re = tp / torch.sum(y).item()
|
||||
f1 = 2 * pr * re / (pr + re)
|
||||
|
||||
y = y.cpu().numpy()
|
||||
p_hat = p_hat.data.cpu().numpy()
|
||||
|
||||
aupr = average_precision(y, p_hat)
|
||||
|
||||
return loss, correct, mse, pr, re, f1, aupr
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Run training from arguments.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
|
||||
output = args.outfile
|
||||
if output is None:
|
||||
output = sys.stdout
|
||||
else:
|
||||
output = open(output, "w")
|
||||
|
||||
print(f'# Called as: {" ".join(sys.argv)}', file=output)
|
||||
if output is not sys.stdout:
|
||||
print(f'Called as: {" ".join(sys.argv)}')
|
||||
|
||||
# Set device
|
||||
device = args.device
|
||||
use_cuda = (device >= 0) and torch.cuda.is_available()
|
||||
if use_cuda:
|
||||
torch.cuda.set_device(device)
|
||||
print(
|
||||
f"# Using CUDA device {device} - {torch.cuda.get_device_name(device)}",
|
||||
file=output,
|
||||
)
|
||||
else:
|
||||
print("# Using CPU", file=output)
|
||||
device = "cpu"
|
||||
|
||||
batch_size = args.batch_size
|
||||
|
||||
train_fi = args.train
|
||||
test_fi = args.val
|
||||
augment = args.augment
|
||||
embedding_h5 = args.embedding
|
||||
h5fi = h5py.File(embedding_h5, "r")
|
||||
|
||||
print(f"# Loading training pairs from {train_fi}...", file=output)
|
||||
output.flush()
|
||||
|
||||
train_df = pd.read_csv(train_fi, sep="\t", header=None)
|
||||
if augment:
|
||||
train_n0 = pd.concat((train_df[0], train_df[1]), axis=0).reset_index(
|
||||
drop=True
|
||||
)
|
||||
train_n1 = pd.concat((train_df[1], train_df[0]), axis=0).reset_index(
|
||||
drop=True
|
||||
)
|
||||
train_y = torch.from_numpy(
|
||||
pd.concat((train_df[2], train_df[2])).values
|
||||
)
|
||||
else:
|
||||
train_n0, train_n1 = train_df[0], train_df[1]
|
||||
train_y = torch.from_numpy(train_df[2].values)
|
||||
|
||||
print(f"# Loading testing pairs from {test_fi}...", file=output)
|
||||
output.flush()
|
||||
|
||||
test_df = pd.read_csv(test_fi, sep="\t", header=None)
|
||||
test_n0, test_n1 = test_df[0], test_df[1]
|
||||
test_y = torch.from_numpy(test_df[2].values)
|
||||
output.flush()
|
||||
|
||||
train_pairs = PairedDataset(train_n0, train_n1, train_y)
|
||||
pairs_train_iterator = torch.utils.data.DataLoader(
|
||||
train_pairs,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_paired_sequences,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
test_pairs = PairedDataset(test_n0, test_n1, test_y)
|
||||
pairs_test_iterator = torch.utils.data.DataLoader(
|
||||
test_pairs,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_paired_sequences,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
output.flush()
|
||||
|
||||
print(f"# Loading embeddings", file=output)
|
||||
tensors = {}
|
||||
all_proteins = (
|
||||
set(train_n0)
|
||||
.union(set(train_n1))
|
||||
.union(set(test_n0))
|
||||
.union(set(test_n1))
|
||||
)
|
||||
for prot_name in tqdm(all_proteins):
|
||||
tensors[prot_name] = torch.from_numpy(h5fi[prot_name][:, :])
|
||||
|
||||
use_cuda = (args.device > -1) and torch.cuda.is_available()
|
||||
|
||||
if args.checkpoint is None:
|
||||
|
||||
projection_dim = args.projection_dim
|
||||
dropout_p = args.dropout_p
|
||||
embedding = FullyConnectedEmbed(
|
||||
6165, projection_dim, dropout=dropout_p
|
||||
)
|
||||
print("# Initializing embedding model with:", file=output)
|
||||
print(f"\tprojection_dim: {projection_dim}", file=output)
|
||||
print(f"\tdropout_p: {dropout_p}", file=output)
|
||||
|
||||
# Create contact model
|
||||
hidden_dim = args.hidden_dim
|
||||
kernel_width = args.kernel_width
|
||||
print("# Initializing contact model with:", file=output)
|
||||
print(f"\thidden_dim: {hidden_dim}", file=output)
|
||||
print(f"\tkernel_width: {kernel_width}", file=output)
|
||||
|
||||
contact = ContactCNN(projection_dim, hidden_dim, kernel_width)
|
||||
|
||||
# Create the full model
|
||||
use_W = args.use_w
|
||||
pool_width = args.pool_width
|
||||
print("# Initializing interaction model with:", file=output)
|
||||
print(f"\tpool_width: {pool_width}", file=output)
|
||||
print(f"\tuse_w: {use_W}", file=output)
|
||||
model = ModelInteraction(
|
||||
embedding, contact, use_W=use_W, pool_size=pool_width
|
||||
)
|
||||
|
||||
print(model, file=output)
|
||||
|
||||
else:
|
||||
print(
|
||||
"# Loading model from checkpoint {}".format(args.checkpoint),
|
||||
file=output,
|
||||
)
|
||||
model = torch.load(args.checkpoint)
|
||||
model.use_cuda = use_cuda
|
||||
|
||||
if use_cuda:
|
||||
model = model.cuda()
|
||||
|
||||
# Train the model
|
||||
lr = args.lr
|
||||
wd = args.weight_decay
|
||||
num_epochs = args.num_epochs
|
||||
batch_size = args.batch_size
|
||||
report_steps = args.epoch_scale
|
||||
inter_weight = args.lambda_
|
||||
cmap_weight = 1 - inter_weight
|
||||
digits = int(np.floor(np.log10(num_epochs))) + 1
|
||||
save_prefix = args.save_prefix
|
||||
if save_prefix is None:
|
||||
save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
|
||||
|
||||
params = [p for p in model.parameters() if p.requires_grad]
|
||||
optim = torch.optim.Adam(params, lr=lr, weight_decay=wd)
|
||||
|
||||
print(f'# Using save prefix "{save_prefix}"', file=output)
|
||||
print(f"# Training with Adam: lr={lr}, weight_decay={wd}", file=output)
|
||||
print(f"\tnum_epochs: {num_epochs}", file=output)
|
||||
print(f"\tepoch_scale: {report_steps}", file=output)
|
||||
print(f"\tbatch_size: {batch_size}", file=output)
|
||||
print(f"\tinteraction weight: {inter_weight}", file=output)
|
||||
print(f"\tcontact map weight: {cmap_weight}", file=output)
|
||||
output.flush()
|
||||
|
||||
batch_report_fmt = (
|
||||
"# [{}/{}] training {:.1%}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}"
|
||||
)
|
||||
epoch_report_fmt = "# Finished Epoch {}/{}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}, Precision={:.6}, Recall={:.6}, F1={:.6}, AUPR={:.6}"
|
||||
|
||||
N = len(pairs_train_iterator) * batch_size
|
||||
for epoch in range(num_epochs):
|
||||
|
||||
model.train()
|
||||
|
||||
n = 0
|
||||
loss_accum = 0
|
||||
acc_accum = 0
|
||||
mse_accum = 0
|
||||
|
||||
# Train batches
|
||||
for (z0, z1, y) in tqdm(
|
||||
pairs_train_iterator,
|
||||
desc=f"Epoch {epoch+1}/{num_epochs}",
|
||||
total=len(pairs_train_iterator),
|
||||
):
|
||||
|
||||
loss, correct, mse, b = interaction_grad(
|
||||
model, z0, z1, y, tensors, use_cuda, weight=inter_weight
|
||||
)
|
||||
|
||||
n += b
|
||||
delta = b * (loss - loss_accum)
|
||||
loss_accum += delta / n
|
||||
|
||||
delta = correct - b * acc_accum
|
||||
acc_accum += delta / n
|
||||
|
||||
delta = b * (mse - mse_accum)
|
||||
mse_accum += delta / n
|
||||
|
||||
report = (n - b) // 100 < n // 100
|
||||
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
model.clip()
|
||||
|
||||
if report:
|
||||
tokens = [
|
||||
epoch + 1,
|
||||
num_epochs,
|
||||
n / N,
|
||||
loss_accum,
|
||||
acc_accum,
|
||||
mse_accum,
|
||||
]
|
||||
if output is not sys.stdout:
|
||||
print(batch_report_fmt.format(*tokens), file=output)
|
||||
output.flush()
|
||||
|
||||
if (epoch + 1) % report_steps == 0:
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
(
|
||||
inter_loss,
|
||||
inter_correct,
|
||||
inter_mse,
|
||||
inter_pr,
|
||||
inter_re,
|
||||
inter_f1,
|
||||
inter_aupr,
|
||||
) = interaction_eval(
|
||||
model, pairs_test_iterator, tensors, use_cuda
|
||||
)
|
||||
tokens = [
|
||||
epoch + 1,
|
||||
num_epochs,
|
||||
inter_loss,
|
||||
inter_correct / (len(pairs_test_iterator) * batch_size),
|
||||
inter_mse,
|
||||
inter_pr,
|
||||
inter_re,
|
||||
inter_f1,
|
||||
inter_aupr,
|
||||
]
|
||||
print(epoch_report_fmt.format(*tokens), file=output)
|
||||
output.flush()
|
||||
|
||||
# Save the model
|
||||
if save_prefix is not None:
|
||||
save_path = (
|
||||
save_prefix
|
||||
+ "_epoch"
|
||||
+ str(epoch + 1).zfill(digits)
|
||||
+ ".sav"
|
||||
)
|
||||
print(f"# Saving model to {save_path}", file=output)
|
||||
model.cpu()
|
||||
torch.save(model, save_path)
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
|
||||
output.flush()
|
||||
|
||||
if save_prefix is not None:
|
||||
save_path = save_prefix + "_final.sav"
|
||||
print(f"# Saving final model to {save_path}", file=output)
|
||||
model.cpu()
|
||||
torch.save(model, save_path)
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
|
||||
output.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_args(parser)
|
||||
main(parser.parse_args())
|
||||
@@ -1,170 +0,0 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import subprocess as sp
|
||||
import sys
|
||||
import gzip as gz
|
||||
from datetime import datetime
|
||||
from .fasta import parse
|
||||
|
||||
|
||||
def log(msg, file=sys.stderr):
|
||||
"""
|
||||
Log datetime-stamped message to file
|
||||
|
||||
:param msg: Message to log
|
||||
:param f: Writable file object to log message to
|
||||
"""
|
||||
timestr = datetime.utcnow().isoformat(sep="-", timespec="milliseconds")
|
||||
file.write(f"[{timestr}] {msg}\n")
|
||||
file.flush()
|
||||
|
||||
|
||||
def plot_PR_curve(y, phat, saveFile=None):
|
||||
"""
|
||||
Plot precision-recall curve.
|
||||
|
||||
:param y: Labels
|
||||
:type y: np.ndarray
|
||||
:param phat: Predicted probabilities
|
||||
:type phat: np.ndarray
|
||||
:param saveFile: File for plot of curve to be saved to
|
||||
:type saveFile: str
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import precision_recall_curve, average_precision_score
|
||||
|
||||
aupr = average_precision_score(y, phat)
|
||||
precision, recall, _ = precision_recall_curve(y, phat)
|
||||
|
||||
plt.step(recall, precision, color="b", alpha=0.2, where="post")
|
||||
plt.fill_between(recall, precision, step="post", alpha=0.2, color="b")
|
||||
plt.xlabel("Recall")
|
||||
plt.ylabel("Precision")
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.title("Precision-Recall (AUPR: {:.3})".format(aupr))
|
||||
if saveFile:
|
||||
plt.savefig(saveFile)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_ROC_curve(y, phat, saveFile=None):
|
||||
"""
|
||||
Plot receiver operating characteristic curve.
|
||||
|
||||
:param y: Labels
|
||||
:type y: np.ndarray
|
||||
:param phat: Predicted probabilities
|
||||
:type phat: np.ndarray
|
||||
:param saveFile: File for plot of curve to be saved to
|
||||
:type saveFile: str
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import roc_curve, roc_auc_score
|
||||
|
||||
auroc = roc_auc_score(y, phat)
|
||||
|
||||
fpr, tpr, roc_thresh = roc_curve(y, phat)
|
||||
print("AUROC:", auroc)
|
||||
|
||||
plt.step(fpr, tpr, color="b", alpha=0.2, where="post")
|
||||
plt.fill_between(fpr, tpr, step="post", alpha=0.2, color="b")
|
||||
plt.xlabel("FPR")
|
||||
plt.ylabel("TPR")
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.title("Receiver Operating Characteristic (AUROC: {:.3})".format(auroc))
|
||||
if saveFile:
|
||||
plt.savefig(saveFile)
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def RBF(D, sigma=None):
|
||||
"""
|
||||
Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.
|
||||
|
||||
:math:`RBF(x,x') = \\exp{\\frac{-(x - x')^{2}}{2\\sigma^{2}}}`
|
||||
|
||||
:param D: Distance matrix
|
||||
:type D: np.ndarray
|
||||
:param sigma: Bandwith of RBF Kernel [default: :math:`\\sqrt{\\text{max}(D)}`]
|
||||
:type sigma: float
|
||||
:return: Similarity matrix
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
sigma = sigma or np.sqrt(np.max(D))
|
||||
return np.exp(-1 * (np.square(D) / (2 * sigma ** 2)))
|
||||
|
||||
|
||||
def gpu_mem(device):
|
||||
"""
|
||||
Get current memory usage for GPU.
|
||||
|
||||
:param device: GPU device number
|
||||
:type device: int
|
||||
:return: memory used, memory total
|
||||
:rtype: int, int
|
||||
"""
|
||||
result = sp.check_output(
|
||||
[
|
||||
"nvidia-smi",
|
||||
"--query-gpu=memory.used,memory.total",
|
||||
"--format=csv,nounits,noheader",
|
||||
"--id={}".format(device),
|
||||
],
|
||||
encoding="utf-8",
|
||||
)
|
||||
gpu_memory = [int(x) for x in result.strip().split(",")]
|
||||
return gpu_memory[0], gpu_memory[1]
|
||||
|
||||
|
||||
class PairedDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset to be used by the PyTorch data loader for pairs of sequences and their labels.
|
||||
|
||||
:param X0: List of first item in the pair
|
||||
:param X1: List of second item in the pair
|
||||
:param Y: List of labels
|
||||
"""
|
||||
|
||||
def __init__(self, X0, X1, Y):
|
||||
self.X0 = X0
|
||||
self.X1 = X1
|
||||
self.Y = Y
|
||||
assert len(X0) == len(X1), (
|
||||
"X0: "
|
||||
+ str(len(X0))
|
||||
+ " X1: "
|
||||
+ str(len(X1))
|
||||
+ " Y: "
|
||||
+ str(len(Y))
|
||||
)
|
||||
assert len(X0) == len(Y), (
|
||||
"X0: "
|
||||
+ str(len(X0))
|
||||
+ " X1: "
|
||||
+ str(len(X1))
|
||||
+ " Y: "
|
||||
+ str(len(Y))
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X0)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.X0[i], self.X1[i], self.Y[i]
|
||||
|
||||
|
||||
def collate_paired_sequences(args):
|
||||
"""
|
||||
Collate function for PyTorch data loader.
|
||||
"""
|
||||
x0 = [a[0] for a in args]
|
||||
x1 = [a[1] for a in args]
|
||||
y = [a[2] for a in args]
|
||||
return x0, x1, torch.stack(y, 0)
|
||||
@@ -58,45 +58,6 @@ class FullyConnectedEmbed(nn.Module):
|
||||
return t
|
||||
|
||||
|
||||
class LSTMEmbed(nn.Module):
|
||||
def __init__(self, nout, activation="ReLU", sparse=False, p=0.5):
|
||||
super(LSTMEmbed, self).__init__()
|
||||
self.activation = activation
|
||||
self.sparse = sparse
|
||||
self.p = p
|
||||
|
||||
self.embedding = SkipLSTM(21, nout, 1024, 3)
|
||||
|
||||
for param in self.embedding.parameters():
|
||||
param.requires_grad = False
|
||||
torch.nn.init.normal_(self.embedding.proj.weight)
|
||||
torch.nn.init.uniform_(self.embedding.proj.bias, 0, 0)
|
||||
self.embedding.proj.weight.requires_grad = True
|
||||
self.embedding.proj.bias.requires_grad = True
|
||||
|
||||
self.activationDict = nn.ModuleDict(
|
||||
{
|
||||
"None": IdentityEmbed(),
|
||||
"ReLU": nn.ReLU(),
|
||||
"Sigmoid": nn.Sigmoid(),
|
||||
}
|
||||
)
|
||||
self.dropout = nn.Dropout(p=self.p)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
t = self.embedding(x)
|
||||
if self.activation:
|
||||
t = self.activationDict[self.activation](t)
|
||||
if self.sparse:
|
||||
t = self.dropout(t)
|
||||
|
||||
return t
|
||||
|
||||
def long_embed(self, x):
|
||||
return self.embedding.transform(x)
|
||||
|
||||
|
||||
class SkipLSTM(nn.Module):
|
||||
"""
|
||||
Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging as logg
|
||||
import os
|
||||
import os.path
|
||||
import sys
|
||||
from urllib.error import HTTPError
|
||||
from functools import wraps, partial
|
||||
|
||||
import torch
|
||||
|
||||
@@ -45,6 +47,16 @@ def build_human_1(state_dict_path):
|
||||
|
||||
VALID_MODELS = {"lm_v1": build_lm_1, "human_v1": build_human_1}
|
||||
|
||||
STATE_DICT_BASENAME = "dscript_{version}.pt"
|
||||
|
||||
|
||||
def get_state_dict_path(version: str) -> str:
|
||||
state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
|
||||
state_dict_fullname = (
|
||||
f"{state_dict_basedir}/{STATE_DICT_BASENAME.format(version=version)}"
|
||||
)
|
||||
return state_dict_fullname
|
||||
|
||||
|
||||
def get_state_dict(version="human_v1", verbose=True):
|
||||
"""
|
||||
@@ -57,23 +69,57 @@ def get_state_dict(version="human_v1", verbose=True):
|
||||
:return: Path to state dictionary for pre-trained language model
|
||||
:rtype: str
|
||||
"""
|
||||
state_dict_basename = f"dscript_{version}.pt"
|
||||
state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
|
||||
state_dict_fullname = f"{state_dict_basedir}/{state_dict_basename}"
|
||||
state_dict_url = (
|
||||
f"http://cb.csail.mit.edu/cb/dscript/data/models/{state_dict_basename}"
|
||||
)
|
||||
try:
|
||||
if verbose:
|
||||
logg.info(f"Downloading model {version} from {state_dict_url}...")
|
||||
get_local_or_download(state_dict_fullname, state_dict_url)
|
||||
except HTTPError as e:
|
||||
logg.error("Unable to download model - {}".format(e))
|
||||
sys.exit(1)
|
||||
state_dict_fullname = get_state_dict_path(version)
|
||||
state_dict_url = f"http://cb.csail.mit.edu/cb/dscript/data/models/{STATE_DICT_BASENAME.format(version=version)}"
|
||||
if not os.path.exists(state_dict_fullname):
|
||||
try:
|
||||
import shutil
|
||||
import urllib.request
|
||||
|
||||
if verbose:
|
||||
logg.info(
|
||||
f"Downloading model {version} from {state_dict_url}..."
|
||||
)
|
||||
with urllib.request.urlopen(state_dict_url) as response, open(
|
||||
state_dict_fullname, "wb"
|
||||
) as out_file:
|
||||
shutil.copyfileobj(response, out_file)
|
||||
except Exception as e:
|
||||
logg.info("Unable to download model - {}".format(e))
|
||||
sys.exit(1)
|
||||
return state_dict_fullname
|
||||
|
||||
|
||||
def get_pretrained(version="human_v1", verbose=True):
|
||||
def retry(retry_count: int):
|
||||
def decorate(func):
|
||||
@wraps(func)
|
||||
def retry_wrapper(*args, **kwargs):
|
||||
attempt = 0
|
||||
version = args[0]
|
||||
while attempt < retry_count:
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except RuntimeError as e:
|
||||
logg.info(
|
||||
f"\033[93mLoading {version} from disk failed. Retrying download attempt: {attempt + 1}\033[0m"
|
||||
)
|
||||
if e.args[0].startswith("unexpected EOF"):
|
||||
state_dict_fullname = get_state_dict_path(version)
|
||||
if os.path.exists(state_dict_fullname):
|
||||
os.remove(state_dict_fullname)
|
||||
else:
|
||||
raise e
|
||||
attempt += 1
|
||||
raise Exception(f"Failed to download {version}")
|
||||
|
||||
return retry_wrapper
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
@retry(3)
|
||||
def get_pretrained(version="human_v1"):
|
||||
"""
|
||||
Get pre-trained model object.
|
||||
|
||||
@@ -95,5 +141,5 @@ def get_pretrained(version="human_v1", verbose=True):
|
||||
if version not in VALID_MODELS:
|
||||
raise ValueError("Model {} does not exist".format(version))
|
||||
|
||||
state_dict_path = get_state_dict(version, verbose=verbose)
|
||||
state_dict_path = get_state_dict(version)
|
||||
return VALID_MODELS[version](state_dict_path)
|
||||
|
||||
3
dscript/tests/test.tsv
Normal file
3
dscript/tests/test.tsv
Normal file
@@ -0,0 +1,3 @@
|
||||
seq1 seq2 1
|
||||
seq1 seq3 0
|
||||
seq2 seq3 1
|
||||
|
11
requirements.txt
Normal file
11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
torch==1.11
|
||||
biopython
|
||||
h5py
|
||||
matplotlib
|
||||
numpy
|
||||
pandas
|
||||
scikit-learn
|
||||
scipy
|
||||
seaborn
|
||||
setuptools
|
||||
tqdm
|
||||
Reference in New Issue
Block a user