mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-05 07:24:22 +08:00
172 lines
4.3 KiB
Python
172 lines
4.3 KiB
Python
import gzip as gz
|
|
import subprocess as sp
|
|
import sys
|
|
from datetime import datetime
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
import torch.utils.data
|
|
|
|
from .fasta import parse
|
|
|
|
|
|
def log(msg, file=sys.stderr):
|
|
"""
|
|
Log datetime-stamped message to file
|
|
|
|
:param msg: Message to log
|
|
:param f: Writable file object to log message to
|
|
"""
|
|
timestr = datetime.utcnow().isoformat(sep="-", timespec="milliseconds")
|
|
file.write(f"[{timestr}] {msg}\n")
|
|
file.flush()
|
|
|
|
|
|
def plot_PR_curve(y, phat, saveFile=None):
|
|
"""
|
|
Plot precision-recall curve.
|
|
|
|
:param y: Labels
|
|
:type y: np.ndarray
|
|
:param phat: Predicted probabilities
|
|
:type phat: np.ndarray
|
|
:param saveFile: File for plot of curve to be saved to
|
|
:type saveFile: str
|
|
"""
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.metrics import average_precision_score, precision_recall_curve
|
|
|
|
aupr = average_precision_score(y, phat)
|
|
precision, recall, _ = precision_recall_curve(y, phat)
|
|
|
|
plt.step(recall, precision, color="b", alpha=0.2, where="post")
|
|
plt.fill_between(recall, precision, step="post", alpha=0.2, color="b")
|
|
plt.xlabel("Recall")
|
|
plt.ylabel("Precision")
|
|
plt.ylim([0.0, 1.05])
|
|
plt.xlim([0.0, 1.0])
|
|
plt.title("Precision-Recall (AUPR: {:.3})".format(aupr))
|
|
if saveFile:
|
|
plt.savefig(saveFile)
|
|
else:
|
|
plt.show()
|
|
|
|
|
|
def plot_ROC_curve(y, phat, saveFile=None):
|
|
"""
|
|
Plot receiver operating characteristic curve.
|
|
|
|
:param y: Labels
|
|
:type y: np.ndarray
|
|
:param phat: Predicted probabilities
|
|
:type phat: np.ndarray
|
|
:param saveFile: File for plot of curve to be saved to
|
|
:type saveFile: str
|
|
"""
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.metrics import roc_auc_score, roc_curve
|
|
|
|
auroc = roc_auc_score(y, phat)
|
|
|
|
fpr, tpr, roc_thresh = roc_curve(y, phat)
|
|
print("AUROC:", auroc)
|
|
|
|
plt.step(fpr, tpr, color="b", alpha=0.2, where="post")
|
|
plt.fill_between(fpr, tpr, step="post", alpha=0.2, color="b")
|
|
plt.xlabel("FPR")
|
|
plt.ylabel("TPR")
|
|
plt.ylim([0.0, 1.05])
|
|
plt.xlim([0.0, 1.0])
|
|
plt.title("Receiver Operating Characteristic (AUROC: {:.3})".format(auroc))
|
|
if saveFile:
|
|
plt.savefig(saveFile)
|
|
else:
|
|
plt.show()
|
|
|
|
|
|
def RBF(D, sigma=None):
|
|
"""
|
|
Convert distance matrix into similarity matrix using Radial Basis Function (RBF) Kernel.
|
|
|
|
:math:`RBF(x,x') = \\exp{\\frac{-(x - x')^{2}}{2\\sigma^{2}}}`
|
|
|
|
:param D: Distance matrix
|
|
:type D: np.ndarray
|
|
:param sigma: Bandwith of RBF Kernel [default: :math:`\\sqrt{\\text{max}(D)}`]
|
|
:type sigma: float
|
|
:return: Similarity matrix
|
|
:rtype: np.ndarray
|
|
"""
|
|
sigma = sigma or np.sqrt(np.max(D))
|
|
return np.exp(-1 * (np.square(D) / (2 * sigma ** 2)))
|
|
|
|
|
|
def gpu_mem(device):
|
|
"""
|
|
Get current memory usage for GPU.
|
|
|
|
:param device: GPU device number
|
|
:type device: int
|
|
:return: memory used, memory total
|
|
:rtype: int, int
|
|
"""
|
|
result = sp.check_output(
|
|
[
|
|
"nvidia-smi",
|
|
"--query-gpu=memory.used,memory.total",
|
|
"--format=csv,nounits,noheader",
|
|
"--id={}".format(device),
|
|
],
|
|
encoding="utf-8",
|
|
)
|
|
gpu_memory = [int(x) for x in result.strip().split(",")]
|
|
return gpu_memory[0], gpu_memory[1]
|
|
|
|
|
|
class PairedDataset(torch.utils.data.Dataset):
|
|
"""
|
|
Dataset to be used by the PyTorch data loader for pairs of sequences and their labels.
|
|
|
|
:param X0: List of first item in the pair
|
|
:param X1: List of second item in the pair
|
|
:param Y: List of labels
|
|
"""
|
|
|
|
def __init__(self, X0, X1, Y):
|
|
self.X0 = X0
|
|
self.X1 = X1
|
|
self.Y = Y
|
|
assert len(X0) == len(X1), (
|
|
"X0: "
|
|
+ str(len(X0))
|
|
+ " X1: "
|
|
+ str(len(X1))
|
|
+ " Y: "
|
|
+ str(len(Y))
|
|
)
|
|
assert len(X0) == len(Y), (
|
|
"X0: "
|
|
+ str(len(X0))
|
|
+ " X1: "
|
|
+ str(len(X1))
|
|
+ " Y: "
|
|
+ str(len(Y))
|
|
)
|
|
|
|
def __len__(self):
|
|
return len(self.X0)
|
|
|
|
def __getitem__(self, i):
|
|
return self.X0[i], self.X1[i], self.Y[i]
|
|
|
|
|
|
def collate_paired_sequences(args):
|
|
"""
|
|
Collate function for PyTorch data loader.
|
|
"""
|
|
x0 = [a[0] for a in args]
|
|
x1 = [a[1] for a in args]
|
|
y = [a[2] for a in args]
|
|
return x0, x1, torch.stack(y, 0)
|