mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
merge
This commit is contained in:
@@ -5,6 +5,7 @@ Train a new model.
|
||||
import argparse
|
||||
import datetime
|
||||
import gzip as gz
|
||||
import os
|
||||
import subprocess as sp
|
||||
import sys
|
||||
|
||||
@@ -17,10 +18,8 @@ import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from sklearn.metrics import average_precision_score as average_precision
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import dscript
|
||||
from dscript.models.contact import ContactCNN
|
||||
from dscript.models.embedding import FullyConnectedEmbed, IdentityEmbed
|
||||
from dscript.models.interaction import ModelInteraction
|
||||
@@ -343,7 +342,7 @@ def main(args):
|
||||
else:
|
||||
output = open(output, "w")
|
||||
|
||||
print(f'# Called as: {" ".join(sys.argv)}', file=output)
|
||||
print(f'Called as: {" ".join(sys.argv)}', file=output)
|
||||
if output is not sys.stdout:
|
||||
print(f'Called as: {" ".join(sys.argv)}')
|
||||
|
||||
@@ -351,6 +350,7 @@ def main(args):
|
||||
device = args.device
|
||||
use_cuda = (device >= 0) and torch.cuda.is_available()
|
||||
if use_cuda:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = device
|
||||
torch.cuda.set_device(device)
|
||||
print(
|
||||
f"# Using CUDA device {device} - {torch.cuda.get_device_name(device)}",
|
||||
|
||||
Reference in New Issue
Block a user