diff --git a/preprocessing/pipelines.py b/preprocessing/pipelines.py index 79ee53c..64d6fff 100644 --- a/preprocessing/pipelines.py +++ b/preprocessing/pipelines.py @@ -2,6 +2,8 @@ import numpy as np from utilities import io_utils from utilities.paths import pipeline_folder,MSA_folder,structures_folder from utilities.dataset_utils import align_labels +from multiprocessing import Pool +from functools import partial import time import os try: @@ -142,9 +144,10 @@ class Pipeline(): verbose= True, fresh= False, save = True, - permissive=True + permissive=True, + overwrite=False, + ncores = 1 ): - location_processed_dataset = pipeline_folder + dataset_name + '_%s.data' % self.pipeline_name found = False @@ -187,7 +190,9 @@ class Pipeline(): structures_folder=structures_folder, MSA_folder=MSA_folder, verbose=verbose, - permissive=permissive + permissive=permissive, + overwrite=overwrite, + ncores=ncores ) print('Processed dataset built... (t=%.f s)' % (time.time() - t)) if save: @@ -207,7 +212,9 @@ class Pipeline(): structures_folder=structures_folder, MSA_folder=MSA_folder, verbose=True, - permissive=True + overwrite=False, + permissive=True, + ncores = 1 ): B = len(list_origins) if (list_labels is not None): @@ -218,77 +225,157 @@ class Pipeline(): if (list_resids is not None): assert len(list_resids) == B - - inputs = [] - outputs = [] - failed_samples = [] - for b, origin in enumerate(list_origins): - if verbose: - print('Processing example %s (%s/%s)' % (origin, b, B)) - - try: - pdbfile, chain_ids = PDBio.getPDB(origin, biounit=biounit, structures_folder=structures_folder) - struct, chain_objs = PDBio.load_chains(file=pdbfile, chain_ids=chain_ids) - - if ('PWM' in self.requirements) | ('conservation' in self.requirements): - sequences = [PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs] - output_alignments = [MSA_folder + 'MSA_' + '%s_%s_%s' % ( - PDBio.parse_str(origin)[0].split('/')[-1].split('.')[0], chain_id[0], chain_id[1]) + '.fasta' for chain_id in chain_ids] - MSA_files = [sequence_utils.call_hhblits(sequence, output_alignment,overwrite=False) for sequence, output_alignment in - zip(sequences, output_alignments)] - if len(MSA_files) == 1: - MSA_files = MSA_files[0] - else: - MSA_files = None - - if has_labels: - labels = list_labels[b] - pdb_resids = PDB_processing.get_PDB_indices(chain_objs, return_model=True, return_chain=True) - aligned_labels = align_labels(labels, - pdb_resids, - label_resids=list_resids[b] if list_resids is not None else None) - else: - aligned_labels = None - - input, output = self.process_example( - chain_obj=chain_objs, - MSA_file=MSA_files, - labels=aligned_labels - ) - - inputs.append(input) - if has_labels: - outputs.append(output) - except Exception as e: - print('Failed to process example %s (%s/%s), Error: %s' %(origin,b,B,str(e) ) ) - if permissive: - failed_samples.append(b) - continue - else: - raise ValueError('Failed in non permissive mode') - - - ninputs = len(inputs[0]) if isinstance(inputs[0],list) else 1 - if self.padded: - if ninputs>1: - inputs = [np.stack([input[k] for input in inputs], axis=0) - for k in range(ninputs)] - else: - inputs = np.stack(inputs,axis=0) - if has_labels: - outputs = np.stack(outputs,axis=0) + has_resids = True else: - if ninputs>1: - inputs = [np.array([input[k] for input in inputs]) - for k in range(ninputs)] + has_resids = False + + if ncores>1: + ncores = min(ncores,B) + pool = Pool(ncores) + batch_size = int(np.ceil(B/ncores)) + batch_list_origins = [list_origins[k*batch_size: min( (k+1) * batch_size , B) ] for k in range(ncores)] + if has_labels: + batch_list_labels = [list_labels[k * batch_size: min((k + 1) * batch_size, B)] for k in range(ncores)] else: + batch_list_labels = [None for k in range(ncores)] + if has_resids: + batch_list_resids = [list_resids[k * batch_size: min((k + 1) * batch_size, B)] for k in range(ncores)] + else: + batch_list_resids = [None for k in range(ncores)] + _build_and_process_dataset = partial(self.build_and_process_dataset, + biounit=biounit, + structures_folder=structures_folder, + MSA_folder=MSA_folder, + verbose=verbose, + overwrite=overwrite, + permissive=permissive, + ncores = 1) + batch_outputs = pool.starmap(_build_and_process_dataset,zip(batch_list_origins,batch_list_resids,batch_list_labels)) + pool.close() + ## Determine if input/output are list. + input_is_list = False + output_is_list = False + ninputs = 1 + noutputs = 1 + for ksuccess in range(ncores): + if batch_outputs[ksuccess] != []: + input_is_list = isinstance(batch_outputs[ksuccess][0],list) + ninputs = len(batch_outputs[ksuccess][0]) + if has_labels: + output_is_list = isinstance(batch_outputs[ksuccess][1],list) + noutputs = len(batch_outputs[ksuccess][1]) + break + + if input_is_list: + inputs = [[] for _ in range(ninputs)] + for batch_output in batch_outputs: + if batch_output[0] != []: + for l in range(ninputs): + inputs[l] += list(batch_output[0][l]) + for l in range(ninputs): + inputs[l] = np.array(inputs[l]) + else: + inputs = [] + for batch_output in batch_outputs: + if batch_output[0] != []: + inputs += list(batch_output[0]) inputs = np.array(inputs) if has_labels: - outputs = np.array(outputs) - if has_labels: - return inputs, outputs,failed_samples + if output_is_list: + outputs = [[] for _ in range(noutputs)] + for batch_output in batch_outputs: + if batch_output[1] != []: + for l in range(noutputs): + outputs[l] += list(batch_output[1][l]) + for l in range(noutputs): + outputs[l] = np.array(outputs[l]) + else: + outputs = [] + for batch_output in batch_outputs: + if batch_output[1] != []: + outputs += list(batch_output[1]) + outputs = np.array(outputs) + else: + outputs = None + failed_samples = list(np.concatenate([np.array(batch_outputs[k][2],dtype=np.int)+k*batch_size for k in range(ncores)])) + return inputs,outputs,failed_samples else: - return inputs,None,failed_samples + inputs = [] + outputs = [] + failed_samples = [] + for b, origin in enumerate(list_origins): + if verbose: + print('Processing example %s (%s/%s)' % (origin, b, B)) + + try: + pdbfile, chain_ids = PDBio.getPDB(origin, biounit=biounit, structures_folder=structures_folder) + struct, chain_objs = PDBio.load_chains(file=pdbfile, chain_ids=chain_ids) + + if ('PWM' in self.requirements) | ('conservation' in self.requirements): + sequences = [PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs] + output_alignments = [MSA_folder + 'MSA_' + '%s_%s_%s' % ( + PDBio.parse_str(origin)[0].split('/')[-1].split('.')[0], chain_id[0], chain_id[1]) + '.fasta' for chain_id in chain_ids] + MSA_files = [sequence_utils.call_hhblits(sequence, output_alignment,overwrite=overwrite) for sequence, output_alignment in + zip(sequences, output_alignments)] + if len(MSA_files) == 1: + MSA_files = MSA_files[0] + else: + MSA_files = None + + if has_labels: + labels = list_labels[b] + pdb_resids = PDB_processing.get_PDB_indices(chain_objs, return_model=True, return_chain=True) + aligned_labels = align_labels(labels, + pdb_resids, + label_resids=list_resids[b] if has_resids else None) + else: + aligned_labels = None + + input, output = self.process_example( + chain_obj=chain_objs, + MSA_file=MSA_files, + labels=aligned_labels + ) + + inputs.append(input) + if has_labels: + outputs.append(output) + except Exception as e: + print('Failed to process example %s (%s/%s), Error: %s' %(origin,b,B,str(e) ) ) + if permissive: + failed_samples.append(b) + continue + else: + raise ValueError('Failed in non permissive mode') + + if len(inputs)==0: + # No successful run. + if has_labels: + return [],[],failed_samples + else: + return [],None,failed_samples + ninputs = len(inputs[0]) if isinstance(inputs[0],list) else 1 + + if self.padded: + if ninputs>1: + inputs = [np.stack([input[k] for input in inputs], axis=0) + for k in range(ninputs)] + else: + inputs = np.stack(inputs,axis=0) + if has_labels: + outputs = np.stack(outputs,axis=0) + else: + if ninputs>1: + inputs = [np.array([input[k] for input in inputs]) + for k in range(ninputs)] + else: + inputs = np.array(inputs) + if has_labels: + outputs = np.array(outputs) + if has_labels: + return inputs, outputs,failed_samples + else: + return inputs,None,failed_samples def process_dataset(self, env,label_name=None,permissive=True): return (None,), (None,),None