mirror of
https://github.com/jertubiana/ScanNet.git
synced 2026-06-04 13:44:22 +08:00
Added support for parallel pipeline processing
This commit is contained in:
@@ -2,6 +2,8 @@ import numpy as np
|
||||
from utilities import io_utils
|
||||
from utilities.paths import pipeline_folder,MSA_folder,structures_folder
|
||||
from utilities.dataset_utils import align_labels
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
import time
|
||||
import os
|
||||
try:
|
||||
@@ -142,9 +144,10 @@ class Pipeline():
|
||||
verbose= True,
|
||||
fresh= False,
|
||||
save = True,
|
||||
permissive=True
|
||||
permissive=True,
|
||||
overwrite=False,
|
||||
ncores = 1
|
||||
):
|
||||
|
||||
location_processed_dataset = pipeline_folder + dataset_name + '_%s.data' % self.pipeline_name
|
||||
|
||||
found = False
|
||||
@@ -187,7 +190,9 @@ class Pipeline():
|
||||
structures_folder=structures_folder,
|
||||
MSA_folder=MSA_folder,
|
||||
verbose=verbose,
|
||||
permissive=permissive
|
||||
permissive=permissive,
|
||||
overwrite=overwrite,
|
||||
ncores=ncores
|
||||
)
|
||||
print('Processed dataset built... (t=%.f s)' % (time.time() - t))
|
||||
if save:
|
||||
@@ -207,7 +212,9 @@ class Pipeline():
|
||||
structures_folder=structures_folder,
|
||||
MSA_folder=MSA_folder,
|
||||
verbose=True,
|
||||
permissive=True
|
||||
overwrite=False,
|
||||
permissive=True,
|
||||
ncores = 1
|
||||
):
|
||||
B = len(list_origins)
|
||||
if (list_labels is not None):
|
||||
@@ -218,77 +225,157 @@ class Pipeline():
|
||||
|
||||
if (list_resids is not None):
|
||||
assert len(list_resids) == B
|
||||
|
||||
inputs = []
|
||||
outputs = []
|
||||
failed_samples = []
|
||||
for b, origin in enumerate(list_origins):
|
||||
if verbose:
|
||||
print('Processing example %s (%s/%s)' % (origin, b, B))
|
||||
|
||||
try:
|
||||
pdbfile, chain_ids = PDBio.getPDB(origin, biounit=biounit, structures_folder=structures_folder)
|
||||
struct, chain_objs = PDBio.load_chains(file=pdbfile, chain_ids=chain_ids)
|
||||
|
||||
if ('PWM' in self.requirements) | ('conservation' in self.requirements):
|
||||
sequences = [PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs]
|
||||
output_alignments = [MSA_folder + 'MSA_' + '%s_%s_%s' % (
|
||||
PDBio.parse_str(origin)[0].split('/')[-1].split('.')[0], chain_id[0], chain_id[1]) + '.fasta' for chain_id in chain_ids]
|
||||
MSA_files = [sequence_utils.call_hhblits(sequence, output_alignment,overwrite=False) for sequence, output_alignment in
|
||||
zip(sequences, output_alignments)]
|
||||
if len(MSA_files) == 1:
|
||||
MSA_files = MSA_files[0]
|
||||
else:
|
||||
MSA_files = None
|
||||
|
||||
if has_labels:
|
||||
labels = list_labels[b]
|
||||
pdb_resids = PDB_processing.get_PDB_indices(chain_objs, return_model=True, return_chain=True)
|
||||
aligned_labels = align_labels(labels,
|
||||
pdb_resids,
|
||||
label_resids=list_resids[b] if list_resids is not None else None)
|
||||
else:
|
||||
aligned_labels = None
|
||||
|
||||
input, output = self.process_example(
|
||||
chain_obj=chain_objs,
|
||||
MSA_file=MSA_files,
|
||||
labels=aligned_labels
|
||||
)
|
||||
|
||||
inputs.append(input)
|
||||
if has_labels:
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
print('Failed to process example %s (%s/%s), Error: %s' %(origin,b,B,str(e) ) )
|
||||
if permissive:
|
||||
failed_samples.append(b)
|
||||
continue
|
||||
else:
|
||||
raise ValueError('Failed in non permissive mode')
|
||||
|
||||
|
||||
ninputs = len(inputs[0]) if isinstance(inputs[0],list) else 1
|
||||
if self.padded:
|
||||
if ninputs>1:
|
||||
inputs = [np.stack([input[k] for input in inputs], axis=0)
|
||||
for k in range(ninputs)]
|
||||
else:
|
||||
inputs = np.stack(inputs,axis=0)
|
||||
if has_labels:
|
||||
outputs = np.stack(outputs,axis=0)
|
||||
has_resids = True
|
||||
else:
|
||||
if ninputs>1:
|
||||
inputs = [np.array([input[k] for input in inputs])
|
||||
for k in range(ninputs)]
|
||||
has_resids = False
|
||||
|
||||
if ncores>1:
|
||||
ncores = min(ncores,B)
|
||||
pool = Pool(ncores)
|
||||
batch_size = int(np.ceil(B/ncores))
|
||||
batch_list_origins = [list_origins[k*batch_size: min( (k+1) * batch_size , B) ] for k in range(ncores)]
|
||||
if has_labels:
|
||||
batch_list_labels = [list_labels[k * batch_size: min((k + 1) * batch_size, B)] for k in range(ncores)]
|
||||
else:
|
||||
batch_list_labels = [None for k in range(ncores)]
|
||||
if has_resids:
|
||||
batch_list_resids = [list_resids[k * batch_size: min((k + 1) * batch_size, B)] for k in range(ncores)]
|
||||
else:
|
||||
batch_list_resids = [None for k in range(ncores)]
|
||||
_build_and_process_dataset = partial(self.build_and_process_dataset,
|
||||
biounit=biounit,
|
||||
structures_folder=structures_folder,
|
||||
MSA_folder=MSA_folder,
|
||||
verbose=verbose,
|
||||
overwrite=overwrite,
|
||||
permissive=permissive,
|
||||
ncores = 1)
|
||||
batch_outputs = pool.starmap(_build_and_process_dataset,zip(batch_list_origins,batch_list_resids,batch_list_labels))
|
||||
pool.close()
|
||||
## Determine whether the inputs/outputs are lists.
|
||||
input_is_list = False
|
||||
output_is_list = False
|
||||
ninputs = 1
|
||||
noutputs = 1
|
||||
for ksuccess in range(ncores):
|
||||
if batch_outputs[ksuccess] != []:
|
||||
input_is_list = isinstance(batch_outputs[ksuccess][0],list)
|
||||
ninputs = len(batch_outputs[ksuccess][0])
|
||||
if has_labels:
|
||||
output_is_list = isinstance(batch_outputs[ksuccess][1],list)
|
||||
noutputs = len(batch_outputs[ksuccess][1])
|
||||
break
|
||||
|
||||
if input_is_list:
|
||||
inputs = [[] for _ in range(ninputs)]
|
||||
for batch_output in batch_outputs:
|
||||
if batch_output[0] != []:
|
||||
for l in range(ninputs):
|
||||
inputs[l] += list(batch_output[0][l])
|
||||
for l in range(ninputs):
|
||||
inputs[l] = np.array(inputs[l])
|
||||
else:
|
||||
inputs = []
|
||||
for batch_output in batch_outputs:
|
||||
if batch_output[0] != []:
|
||||
inputs += list(batch_output[0])
|
||||
inputs = np.array(inputs)
|
||||
if has_labels:
|
||||
outputs = np.array(outputs)
|
||||
if has_labels:
|
||||
return inputs, outputs,failed_samples
|
||||
if output_is_list:
|
||||
outputs = [[] for _ in range(noutputs)]
|
||||
for batch_output in batch_outputs:
|
||||
if batch_output[1] != []:
|
||||
for l in range(noutputs):
|
||||
outputs[l] += list(batch_output[1][l])
|
||||
for l in range(noutputs):
|
||||
outputs[l] = np.array(outputs[l])
|
||||
else:
|
||||
outputs = []
|
||||
for batch_output in batch_outputs:
|
||||
if batch_output[1] != []:
|
||||
outputs += list(batch_output[1])
|
||||
outputs = np.array(outputs)
|
||||
else:
|
||||
outputs = None
|
||||
failed_samples = list(np.concatenate([np.array(batch_outputs[k][2],dtype=np.int)+k*batch_size for k in range(ncores)]))
|
||||
return inputs,outputs,failed_samples
|
||||
else:
|
||||
return inputs,None,failed_samples
|
||||
inputs = []
|
||||
outputs = []
|
||||
failed_samples = []
|
||||
for b, origin in enumerate(list_origins):
|
||||
if verbose:
|
||||
print('Processing example %s (%s/%s)' % (origin, b, B))
|
||||
|
||||
try:
|
||||
pdbfile, chain_ids = PDBio.getPDB(origin, biounit=biounit, structures_folder=structures_folder)
|
||||
struct, chain_objs = PDBio.load_chains(file=pdbfile, chain_ids=chain_ids)
|
||||
|
||||
if ('PWM' in self.requirements) | ('conservation' in self.requirements):
|
||||
sequences = [PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs]
|
||||
output_alignments = [MSA_folder + 'MSA_' + '%s_%s_%s' % (
|
||||
PDBio.parse_str(origin)[0].split('/')[-1].split('.')[0], chain_id[0], chain_id[1]) + '.fasta' for chain_id in chain_ids]
|
||||
MSA_files = [sequence_utils.call_hhblits(sequence, output_alignment,overwrite=overwrite) for sequence, output_alignment in
|
||||
zip(sequences, output_alignments)]
|
||||
if len(MSA_files) == 1:
|
||||
MSA_files = MSA_files[0]
|
||||
else:
|
||||
MSA_files = None
|
||||
|
||||
if has_labels:
|
||||
labels = list_labels[b]
|
||||
pdb_resids = PDB_processing.get_PDB_indices(chain_objs, return_model=True, return_chain=True)
|
||||
aligned_labels = align_labels(labels,
|
||||
pdb_resids,
|
||||
label_resids=list_resids[b] if has_resids else None)
|
||||
else:
|
||||
aligned_labels = None
|
||||
|
||||
input, output = self.process_example(
|
||||
chain_obj=chain_objs,
|
||||
MSA_file=MSA_files,
|
||||
labels=aligned_labels
|
||||
)
|
||||
|
||||
inputs.append(input)
|
||||
if has_labels:
|
||||
outputs.append(output)
|
||||
except Exception as e:
|
||||
print('Failed to process example %s (%s/%s), Error: %s' %(origin,b,B,str(e) ) )
|
||||
if permissive:
|
||||
failed_samples.append(b)
|
||||
continue
|
||||
else:
|
||||
raise ValueError('Failed in non permissive mode')
|
||||
|
||||
if len(inputs)==0:
|
||||
# No successful run.
|
||||
if has_labels:
|
||||
return [],[],failed_samples
|
||||
else:
|
||||
return [],None,failed_samples
|
||||
ninputs = len(inputs[0]) if isinstance(inputs[0],list) else 1
|
||||
|
||||
if self.padded:
|
||||
if ninputs>1:
|
||||
inputs = [np.stack([input[k] for input in inputs], axis=0)
|
||||
for k in range(ninputs)]
|
||||
else:
|
||||
inputs = np.stack(inputs,axis=0)
|
||||
if has_labels:
|
||||
outputs = np.stack(outputs,axis=0)
|
||||
else:
|
||||
if ninputs>1:
|
||||
inputs = [np.array([input[k] for input in inputs])
|
||||
for k in range(ninputs)]
|
||||
else:
|
||||
inputs = np.array(inputs)
|
||||
if has_labels:
|
||||
outputs = np.array(outputs)
|
||||
if has_labels:
|
||||
return inputs, outputs,failed_samples
|
||||
else:
|
||||
return inputs,None,failed_samples
|
||||
|
||||
def process_dataset(self, env,label_name=None,permissive=True):
|
||||
return (None,), (None,),None
|
||||
|
||||
Reference in New Issue
Block a user