Added support for parallel pipeline processing

This commit is contained in:
Jérôme Tubiana
2021-11-01 12:57:56 +02:00
parent 9364b8fc38
commit 3efc04f621

View File

@@ -2,6 +2,8 @@ import numpy as np
from utilities import io_utils
from utilities.paths import pipeline_folder,MSA_folder,structures_folder
from utilities.dataset_utils import align_labels
from multiprocessing import Pool
from functools import partial
import time
import os
try:
@@ -142,9 +144,10 @@ class Pipeline():
verbose= True,
fresh= False,
save = True,
permissive=True
permissive=True,
overwrite=False,
ncores = 1
):
location_processed_dataset = pipeline_folder + dataset_name + '_%s.data' % self.pipeline_name
found = False
@@ -187,7 +190,9 @@ class Pipeline():
structures_folder=structures_folder,
MSA_folder=MSA_folder,
verbose=verbose,
permissive=permissive
permissive=permissive,
overwrite=overwrite,
ncores=ncores
)
print('Processed dataset built... (t=%.f s)' % (time.time() - t))
if save:
@@ -207,7 +212,9 @@ class Pipeline():
structures_folder=structures_folder,
MSA_folder=MSA_folder,
verbose=True,
permissive=True
overwrite=False,
permissive=True,
ncores = 1
):
B = len(list_origins)
if (list_labels is not None):
@@ -218,77 +225,157 @@ class Pipeline():
if (list_resids is not None):
assert len(list_resids) == B
inputs = []
outputs = []
failed_samples = []
for b, origin in enumerate(list_origins):
if verbose:
print('Processing example %s (%s/%s)' % (origin, b, B))
try:
pdbfile, chain_ids = PDBio.getPDB(origin, biounit=biounit, structures_folder=structures_folder)
struct, chain_objs = PDBio.load_chains(file=pdbfile, chain_ids=chain_ids)
if ('PWM' in self.requirements) | ('conservation' in self.requirements):
sequences = [PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs]
output_alignments = [MSA_folder + 'MSA_' + '%s_%s_%s' % (
PDBio.parse_str(origin)[0].split('/')[-1].split('.')[0], chain_id[0], chain_id[1]) + '.fasta' for chain_id in chain_ids]
MSA_files = [sequence_utils.call_hhblits(sequence, output_alignment,overwrite=False) for sequence, output_alignment in
zip(sequences, output_alignments)]
if len(MSA_files) == 1:
MSA_files = MSA_files[0]
else:
MSA_files = None
if has_labels:
labels = list_labels[b]
pdb_resids = PDB_processing.get_PDB_indices(chain_objs, return_model=True, return_chain=True)
aligned_labels = align_labels(labels,
pdb_resids,
label_resids=list_resids[b] if list_resids is not None else None)
else:
aligned_labels = None
input, output = self.process_example(
chain_obj=chain_objs,
MSA_file=MSA_files,
labels=aligned_labels
)
inputs.append(input)
if has_labels:
outputs.append(output)
except Exception as e:
print('Failed to process example %s (%s/%s), Error: %s' %(origin,b,B,str(e) ) )
if permissive:
failed_samples.append(b)
continue
else:
raise ValueError('Failed in non permissive mode')
ninputs = len(inputs[0]) if isinstance(inputs[0],list) else 1
if self.padded:
if ninputs>1:
inputs = [np.stack([input[k] for input in inputs], axis=0)
for k in range(ninputs)]
else:
inputs = np.stack(inputs,axis=0)
if has_labels:
outputs = np.stack(outputs,axis=0)
has_resids = True
else:
if ninputs>1:
inputs = [np.array([input[k] for input in inputs])
for k in range(ninputs)]
has_resids = False
if ncores>1:
ncores = min(ncores,B)
pool = Pool(ncores)
batch_size = int(np.ceil(B/ncores))
batch_list_origins = [list_origins[k*batch_size: min( (k+1) * batch_size , B) ] for k in range(ncores)]
if has_labels:
batch_list_labels = [list_labels[k * batch_size: min((k + 1) * batch_size, B)] for k in range(ncores)]
else:
batch_list_labels = [None for k in range(ncores)]
if has_resids:
batch_list_resids = [list_resids[k * batch_size: min((k + 1) * batch_size, B)] for k in range(ncores)]
else:
batch_list_resids = [None for k in range(ncores)]
_build_and_process_dataset = partial(self.build_and_process_dataset,
biounit=biounit,
structures_folder=structures_folder,
MSA_folder=MSA_folder,
verbose=verbose,
overwrite=overwrite,
permissive=permissive,
ncores = 1)
batch_outputs = pool.starmap(_build_and_process_dataset,zip(batch_list_origins,batch_list_resids,batch_list_labels))
pool.close()
## Determine if input/output are list.
input_is_list = False
output_is_list = False
ninputs = 1
noutputs = 1
for ksuccess in range(ncores):
if batch_outputs[ksuccess] != []:
input_is_list = isinstance(batch_outputs[ksuccess][0],list)
ninputs = len(batch_outputs[ksuccess][0])
if has_labels:
output_is_list = isinstance(batch_outputs[ksuccess][1],list)
noutputs = len(batch_outputs[ksuccess][1])
break
if input_is_list:
inputs = [[] for _ in range(ninputs)]
for batch_output in batch_outputs:
if batch_output[0] != []:
for l in range(ninputs):
inputs[l] += list(batch_output[0][l])
for l in range(ninputs):
inputs[l] = np.array(inputs[l])
else:
inputs = []
for batch_output in batch_outputs:
if batch_output[0] != []:
inputs += list(batch_output[0])
inputs = np.array(inputs)
if has_labels:
outputs = np.array(outputs)
if has_labels:
return inputs, outputs,failed_samples
if output_is_list:
outputs = [[] for _ in range(noutputs)]
for batch_output in batch_outputs:
if batch_output[1] != []:
for l in range(noutputs):
outputs[l] += list(batch_output[1][l])
for l in range(noutputs):
outputs[l] = np.array(outputs[l])
else:
outputs = []
for batch_output in batch_outputs:
if batch_output[1] != []:
outputs += list(batch_output[1])
outputs = np.array(outputs)
else:
outputs = None
failed_samples = list(np.concatenate([np.array(batch_outputs[k][2],dtype=np.int)+k*batch_size for k in range(ncores)]))
return inputs,outputs,failed_samples
else:
return inputs,None,failed_samples
inputs = []
outputs = []
failed_samples = []
for b, origin in enumerate(list_origins):
if verbose:
print('Processing example %s (%s/%s)' % (origin, b, B))
try:
pdbfile, chain_ids = PDBio.getPDB(origin, biounit=biounit, structures_folder=structures_folder)
struct, chain_objs = PDBio.load_chains(file=pdbfile, chain_ids=chain_ids)
if ('PWM' in self.requirements) | ('conservation' in self.requirements):
sequences = [PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs]
output_alignments = [MSA_folder + 'MSA_' + '%s_%s_%s' % (
PDBio.parse_str(origin)[0].split('/')[-1].split('.')[0], chain_id[0], chain_id[1]) + '.fasta' for chain_id in chain_ids]
MSA_files = [sequence_utils.call_hhblits(sequence, output_alignment,overwrite=overwrite) for sequence, output_alignment in
zip(sequences, output_alignments)]
if len(MSA_files) == 1:
MSA_files = MSA_files[0]
else:
MSA_files = None
if has_labels:
labels = list_labels[b]
pdb_resids = PDB_processing.get_PDB_indices(chain_objs, return_model=True, return_chain=True)
aligned_labels = align_labels(labels,
pdb_resids,
label_resids=list_resids[b] if has_resids else None)
else:
aligned_labels = None
input, output = self.process_example(
chain_obj=chain_objs,
MSA_file=MSA_files,
labels=aligned_labels
)
inputs.append(input)
if has_labels:
outputs.append(output)
except Exception as e:
print('Failed to process example %s (%s/%s), Error: %s' %(origin,b,B,str(e) ) )
if permissive:
failed_samples.append(b)
continue
else:
raise ValueError('Failed in non permissive mode')
if len(inputs)==0:
# No successful run.
if has_labels:
return [],[],failed_samples
else:
return [],None,failed_samples
ninputs = len(inputs[0]) if isinstance(inputs[0],list) else 1
if self.padded:
if ninputs>1:
inputs = [np.stack([input[k] for input in inputs], axis=0)
for k in range(ninputs)]
else:
inputs = np.stack(inputs,axis=0)
if has_labels:
outputs = np.stack(outputs,axis=0)
else:
if ninputs>1:
inputs = [np.array([input[k] for input in inputs])
for k in range(ninputs)]
else:
inputs = np.array(inputs)
if has_labels:
outputs = np.array(outputs)
if has_labels:
return inputs, outputs,failed_samples
else:
return inputs,None,failed_samples
def process_dataset(self, env, label_name=None, permissive=True):
    """Return the fixed placeholder triple ``((None,), (None,), None)``.

    None of the arguments (``env``, ``label_name``, ``permissive``) is read
    by this implementation; the same constant result is produced for any
    input.

    NOTE(review): this looks like a base-class stub meant to be overridden
    by concrete pipelines -- confirm against the subclasses.

    Returns:
        A 3-tuple whose first two items are the one-element tuple
        ``(None,)`` and whose third item is ``None``.
    """
    first_item = (None,)
    second_item = (None,)
    third_item = None
    return first_item, second_item, third_item