diff --git a/preprocessing/sequence_utils.py b/preprocessing/sequence_utils.py index 97b38f0..f567ca8 100644 --- a/preprocessing/sequence_utils.py +++ b/preprocessing/sequence_utils.py @@ -50,9 +50,7 @@ def num2seq(num): return [''.join([aa[min(x, len(aa) - 1)] for x in num_seq]) for num_seq in num] -def load_FASTA(filename, with_labels=False, remove_insertions=True, drop_duplicates=True): - remove_insertions = True - with_labels = True +def load_FASTA(filename, with_labels=True, numerical=True,remove_insertions=True, drop_duplicates=True): count = 0 current_seq = '' all_seqs = [] @@ -74,9 +72,11 @@ def load_FASTA(filename, with_labels=False, remove_insertions=True, drop_duplica [x for x in current_seq if not (x.islower() | (x == '.'))]) all_seqs.append(current_seq) - all_seqs = np.array(list( - map(lambda x: [aadict[y] for y in x], all_seqs[1:])), dtype=curr_int, order="c") - + if numerical: + all_seqs = np.array(list( + map(lambda x: [aadict[y] for y in x], all_seqs[1:])), dtype=curr_int, order="c") + else: + all_seqs = np.array(all_seqs[1:]) if drop_duplicates: all_seqs = pd.DataFrame(all_seqs).drop_duplicates() if with_labels: