diff --git a/dscript/commands/embed.py b/dscript/commands/embed.py index fbdc6db..4a23e8a 100644 --- a/dscript/commands/embed.py +++ b/dscript/commands/embed.py @@ -29,8 +29,11 @@ def add_args(parser): parser.add_argument("--seqs", help="Sequences to be embedded", required=True) parser.add_argument("-o", "--outfile", help="h5 file to write results", required=True) parser.add_argument( - "-d", "--device", type=str, default="cpu", - help="Compute device to use. Options: 'cpu' or GPU index (0, 1, 2, etc.)." + "-d", + "--device", + type=str, + default="cpu", + help="Compute device to use. Options: 'cpu' or GPU index (0, 1, 2, etc.).", ) return parser @@ -45,11 +48,13 @@ def main(args): outPath = args.outfile device_arg = args.device if device_arg.lower() == "cpu": - device = -1 #Refers to CPU in embed_from_fasta + device = -1 # Refers to CPU in embed_from_fasta elif device_arg.isdigit(): # Allow only nonnegative integers device = int(device_arg) else: - log(f"Invalid device argument: {device_arg}. Use 'cpu' or a GPU index. Using CPU.") + log( + f"Invalid device argument: {device_arg}. Use 'cpu' or a GPU index. Using CPU." + ) device = -1 embed_from_fasta(inPath, outPath, device, verbose=True) diff --git a/dscript/commands/predict_serial.py b/dscript/commands/predict_serial.py index 3f25056..43740dd 100644 --- a/dscript/commands/predict_serial.py +++ b/dscript/commands/predict_serial.py @@ -62,8 +62,11 @@ def add_args(parser): ) parser.add_argument("-o", "--outfile", help="File for predictions") parser.add_argument( - "-d", "--device", type=str, default="cpu", - help="Compute device to use. Options: 'cpu' or GPU index (0, 1, 2, etc.)." + "-d", + "--device", + type=str, + default="cpu", + help="Compute device to use. Options: 'cpu' or GPU index (0, 1, 2, etc.).", ) parser.add_argument( "--store_cmaps", @@ -102,11 +105,13 @@ def main(args): embPath = args.embeddings device_arg = args.device if device_arg.lower() == "cpu": - device = -1 #Refers to CPU in embed_from_fasta + device = -1 # Refers to CPU in embed_from_fasta elif device_arg.isdigit(): # Allow only nonnegative integers device = int(device_arg) else: - log(f"Invalid device argument: {device_arg}. Use 'cpu' or a GPU index. Using CPU.") + log( + f"Invalid device argument: {device_arg}. Use 'cpu' or a GPU index. Using CPU." + ) device = -1 threshold = args.thresh @@ -120,7 +125,9 @@ def main(args): logFile = open(logFilePath, "w+") # Set Device - use_cuda = (device >= 0) and torch.cuda.is_available() and device < torch.cuda.device_count() + use_cuda = ( + (device >= 0) and torch.cuda.is_available() and device < torch.cuda.device_count() + ) if use_cuda: torch.cuda.set_device(device) log( diff --git a/dscript/language_model.py b/dscript/language_model.py index 7b4f69f..49b3dc5 100644 --- a/dscript/language_model.py +++ b/dscript/language_model.py @@ -54,7 +54,9 @@ def embed_from_fasta(fastaPath, outputPath, device=0, verbose=False): :param verbose: Print embedding progress :type verbose: bool """ - use_cuda = (device >= 0) and torch.cuda.is_available() and device < torch.cuda.device_count() + use_cuda = ( + (device >= 0) and torch.cuda.is_available() and device < torch.cuda.device_count() + ) if use_cuda: torch.cuda.set_device(device) if verbose: diff --git a/scripts/bmpi_bench/sample_pairs.py b/scripts/bmpi_bench/sample_pairs.py index 25ea31d..e0f3518 100644 --- a/scripts/bmpi_bench/sample_pairs.py +++ b/scripts/bmpi_bench/sample_pairs.py @@ -1,14 +1,13 @@ import random -import sys +import sys f = open(sys.argv[1]) prots = [x.strip() for x in f.readlines()] -pairs = [(p, q) for i, p in enumerate(prots) for q in prots[i+1:]] +pairs = [(p, q) for i, p in enumerate(prots) for q in prots[i + 1 :]] np = len(pairs) -target = int(float(sys.argv[2])*np) +target = int(float(sys.argv[2]) * np) print(f"Choosing {target} pairs from {np}", file=sys.stderr) random.seed(0) sel = random.sample(range(np), k=target) for i in sel: print(*pairs[i], sep="\t") - diff --git a/scripts/bmpi_bench/select_dmel_seqs.py b/scripts/bmpi_bench/select_dmel_seqs.py index c00ccee..ebb2b17 100644 --- a/scripts/bmpi_bench/select_dmel_seqs.py +++ b/scripts/bmpi_bench/select_dmel_seqs.py @@ -1,12 +1,12 @@ import sys -#args: isoform table, fasta, output table, output fasta, max length, output filtered fasta +# args: isoform table, fasta, output table, output fasta, max length, output filtered fasta prots = set() genes = set() with open(sys.argv[1]) as isoforms: for line in isoforms: - if line.isspace() or line[0] == "#" or not(line): + if line.isspace() or line[0] == "#" or not (line): continue tokens = line.strip().split() gene = tokens[0] @@ -29,16 +29,16 @@ else: with open(sys.argv[2]) as fasta: for line in fasta: - if line.isspace() or line[0] == "#" or not(line): + if line.isspace() or line[0] == "#" or not (line): continue if line[0] == ">": if collect: collect = False l = len(seqbuffer) print(id, name, l, sep="\t", file=outTable) - print(">"+id, seqbuffer, sep="\n", file=outFasta) + print(">" + id, seqbuffer, sep="\n", file=outFasta) if filter and l <= thresh: - print(">"+id, seqbuffer, sep="\n", file=outFilter) + print(">" + id, seqbuffer, sep="\n", file=outFilter) seqbuffer = "" id = "" name = "" @@ -57,15 +57,9 @@ with open(sys.argv[2]) as fasta: if collect: l = len(seqbuffer) print(id, name, l, sep="\t", file=outTable) - print(">"+id, seqbuffer, sep="\n", file=outFasta) + print(">" + id, seqbuffer, sep="\n", file=outFasta) if filter and l <= thresh: - print(">"+id, seqbuffer, sep="\n", file=outFilter) + print(">" + id, seqbuffer, sep="\n", file=outFilter) outTable.close() outFasta.close() outFilter.close() - - - - - - diff --git a/scripts/bmpi_bench/select_endo_seqs.py b/scripts/bmpi_bench/select_endo_seqs.py index 7700730..6e94e66 100644 --- a/scripts/bmpi_bench/select_endo_seqs.py +++ b/scripts/bmpi_bench/select_endo_seqs.py @@ -1,5 +1,5 @@ import sys -#args: fasta, min length, max length, output table, output filtered fasta +# args: fasta, min length, max length, output table, output filtered fasta collect = False @@ -11,7 +11,7 @@ thresh = int(sys.argv[3]) floor = int(sys.argv[2]) with open(sys.argv[1]) as fasta: for line in fasta: - if line.isspace() or line[0] == "#" or not(line): + if line.isspace() or line[0] == "#" or not (line): continue if line[0] == ">": if collect: @@ -19,7 +19,7 @@ with open(sys.argv[1]) as fasta: l = len(seqbuffer) print(name, l, sep="\t", file=outTable) if l >= floor and l <= thresh: - print(">"+name, seqbuffer, sep="\n", file=outFilter) + print(">" + name, seqbuffer, sep="\n", file=outFilter) seqbuffer = "" name = "" name = line.split()[0][1:] @@ -30,12 +30,6 @@ if collect: l = len(seqbuffer) print(name, l, sep="\t", file=outTable) if l >= floor and l <= thresh: - print(">"+name, seqbuffer, sep="\n", file=outFilter) + print(">" + name, seqbuffer, sep="\n", file=outFilter) outTable.close() outFilter.close() - - - - - -