diff --git a/baselines/train_handcrafted_features_PPBS.py b/baselines/train_handcrafted_features_PPBS.py index b465b2c..01615d2 100644 --- a/baselines/train_handcrafted_features_PPBS.py +++ b/baselines/train_handcrafted_features_PPBS.py @@ -18,7 +18,7 @@ def make_PR_curves( title = '', figsize=(10, 10), margin=0.05,grid=0.1 - ,fs=25): + ,fs=25,legend_fs=15): import matplotlib matplotlib.use('agg') import matplotlib.pyplot as plt @@ -53,14 +53,14 @@ def make_PR_curves( fig, ax = plt.subplots(figsize=figsize) for i in range(nSubsets): ax.plot(all_PR_curves[i][1], all_PR_curves[i][0], color=subsetColors[i],linewidth=2.0, - label='%s (AUCPR= %.3f)' % (subset_names[i], all_AUCPRs[i])) + label='%s (%.3f)' % (subset_names[i], all_AUCPRs[i])) plt.xticks(np.arange(0, 1.0 + grid, grid), fontsize=fs * 2/3) plt.yticks(np.arange(0, 1.0 + grid, grid), fontsize=fs * 2/3) plt.xlim([0 - margin, 1 + margin]) plt.ylim([0 - margin, 1 + margin]) plt.grid() - plt.legend(fontsize=fs) + plt.legend(fontsize=legend_fs) plt.xlabel('Recall', fontsize=fs) plt.ylabel('Precision', fontsize=fs) plt.title(title,fontsize=fs) diff --git a/preprocessing/protein_frames.py b/preprocessing/protein_frames.py index 06b7c5f..0973773 100644 --- a/preprocessing/protein_frames.py +++ b/preprocessing/protein_frames.py @@ -305,16 +305,20 @@ def _get_aa_frameCloud_quadruplet(atom_coordinates, atom_ids, verbose=True): def add_virtual_atoms(atom_clouds, atom_triplets, verbose=True): virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose) - if np.abs(virtual_atom_clouds).max() >1e8: - print('The weird numba bug happened again at add_virtual_atoms, rerunning once') - virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose) - if np.abs(virtual_atom_clouds).max() > 1e8: - print('The weird numba bug persists...') - else: - print('The weird numba bug was fixed by rerunning') - - if len(virtual_atom_clouds) > 0: + virtual_atom_clouds = np.array(virtual_atom_clouds) + if np.abs(virtual_atom_clouds).max() >1e8: + print('The weird numba bug happened again at add_virtual_atoms, need to fix virtual atoms') + weird_indices = np.nonzero(np.abs(virtual_atom_clouds).max(-1) >1e8 )[0] + print('Fixing %s virtual atoms'%len(weird_indices)) + original_atom_indices = np.array([np.nonzero((atom_triplets[:,1:] == len(atom_triplets)+ index).max(-1))[0][0] for index in weird_indices]) + print(weird_indices,original_atom_indices) + for weird_index, original_atom_index in zip(weird_indices,original_atom_indices): + virtual_atom_clouds[weird_index] = atom_clouds[original_atom_index,:] + if atom_triplets[original_atom_index,1] == weird_index: + virtual_atom_clouds[weird_index][0] +=1 + else: + virtual_atom_clouds[weird_index][2] += 1 atom_clouds = np.concatenate([atom_clouds, np.array(virtual_atom_clouds)], axis=0) return atom_clouds, atom_triplets @@ -377,8 +381,8 @@ def _add_virtual_atoms(atom_clouds, atom_triplets, verbose=True): if __name__ == '__main__': - import PDB_processing import Bio.PDB + from preprocessing import PDBio,PDB_processing PDB_folder = '/Users/jerometubiana/PDB/' pdblist = Bio.PDB.PDBList()