diff --git a/predict_bindingsites.py b/predict_bindingsites.py index 7ff5236..ae07725 100644 --- a/predict_bindingsites.py +++ b/predict_bindingsites.py @@ -661,9 +661,16 @@ def predict_interface_residues( chimera.annotate_pdb_file(pdb_file_locations[i], csv_file, annotated_pdb_file, output_script=True, mini=mini, maxi=maxi) if output_format == 'dictionary': - query_dictionary_predictions = [PDB_processing.make_values_dictionary(query_residue_id,query_prediction) for query_residue_id,query_prediction in zip(query_residue_ids,query_predictions)] + if ((isinstance(layer, list)) | (isinstance(layer, tuple)) | (not aggregate_models)): + query_dictionary_predictions = [PDB_processing.make_values_dictionary(query_residue_ids[k], [query_predictions[l][k] for l in range(len(query_predictions))]) + for k in range(len(query_residue_ids))] + else: + query_dictionary_predictions = [PDB_processing.make_values_dictionary(query_residue_id,query_prediction) for query_residue_id,query_prediction in zip(query_residue_ids,query_predictions)] return query_pdbs,query_names,query_dictionary_predictions else: + if ((isinstance(layer, list)) | (isinstance(layer, tuple)) | (not aggregate_models)): + query_predictions = [ + [query_predictions[i][j] for i in range(len(query_predictions))] for j in range(len(query_predictions[0]))] return query_pdbs,query_names,query_predictions, query_residue_ids, query_sequences diff --git a/predict_features.py b/predict_features.py index 8f3befe..fcd8a9e 100644 --- a/predict_features.py +++ b/predict_features.py @@ -63,6 +63,7 @@ def predict_features(list_queries,layer='SCAN_filter_activity_aa', ) if output_format == 'numpy': query_pdbs, query_names, query_features, query_residue_ids, query_sequences = query_outputs + if return_one: query_pdbs = query_pdbs[0] query_names = query_names[0] @@ -106,7 +107,7 @@ if __name__ == '__main__': if output_format == 'dictionary': - list_names, list_dictionary_features = predict_features(['1a3x_A','1brs_A'],layer=layer,model=model,output_format=output_format) + list_names, list_dictionary_features = predict_features(['1a3x_A','1brs_A'],layer=layer,model=model,output_format=output_format,permissive=True) print('Dictionary format: Dictionary with residue ids as key and features as items.') for k in range(2): print('Query',list_names[k]) @@ -117,7 +118,7 @@ if __name__ == '__main__': else: print('AA',key, 'Features:',item[:5],'Feature shape',item.shape) elif output_format == 'numpy': - list_names, list_features, list_residue_ids = predict_features(['1a3x_A','1brs_A'],layer=layer,model=model,output_format='numpy') + list_names,list_features, list_residue_ids = predict_features(['1a3x_A','1brs_A'],layer=layer,model=model,output_format='numpy',permissive=True) print('Numpy format: Numpy arrays with residue ids as key and features as items.') for k in range(2): print('Query',list_names[k])