Fixed (?) numba bug

This commit is contained in:
Jérôme Tubiana
2021-12-28 11:15:35 +02:00
parent fe5a172f52
commit 8ba8dd2b39
2 changed files with 17 additions and 13 deletions

View File

@@ -18,7 +18,7 @@ def make_PR_curves(
title = '',
figsize=(10, 10),
margin=0.05,grid=0.1
,fs=25):
,fs=25,legend_fs=15):
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
@@ -53,14 +53,14 @@ def make_PR_curves(
fig, ax = plt.subplots(figsize=figsize)
for i in range(nSubsets):
ax.plot(all_PR_curves[i][1], all_PR_curves[i][0], color=subsetColors[i],linewidth=2.0,
label='%s (AUCPR= %.3f)' % (subset_names[i], all_AUCPRs[i]))
label='%s (%.3f)' % (subset_names[i], all_AUCPRs[i]))
plt.xticks(np.arange(0, 1.0 + grid, grid), fontsize=fs * 2/3)
plt.yticks(np.arange(0, 1.0 + grid, grid), fontsize=fs * 2/3)
plt.xlim([0 - margin, 1 + margin])
plt.ylim([0 - margin, 1 + margin])
plt.grid()
plt.legend(fontsize=fs)
plt.legend(fontsize=legend_fs)
plt.xlabel('Recall', fontsize=fs)
plt.ylabel('Precision', fontsize=fs)
plt.title(title,fontsize=fs)

View File

@@ -305,16 +305,20 @@ def _get_aa_frameCloud_quadruplet(atom_coordinates, atom_ids, verbose=True):
def add_virtual_atoms(atom_clouds, atom_triplets, verbose=True):
virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose)
if np.abs(virtual_atom_clouds).max() >1e8:
print('The weird numba bug happened again at add_virtual_atoms, rerunning once')
virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose)
if np.abs(virtual_atom_clouds).max() > 1e8:
print('The weird numba bug persists...')
else:
print('The weird numba bug was fixed by rerunning')
if len(virtual_atom_clouds) > 0:
virtual_atom_clouds = np.array(virtual_atom_clouds)
if np.abs(virtual_atom_clouds).max() >1e8:
print('The weird numba bug happened again at add_virtual_atoms, need to fix virtual atoms')
weird_indices = np.nonzero(np.abs(virtual_atom_clouds).max(-1) >1e8 )[0]
print('Fixing %s virtual atoms'%len(weird_indices))
original_atom_indices = np.array([np.nonzero((atom_triplets[:,1:] == len(atom_triplets)+ index).max(-1))[0][0] for index in weird_indices])
print(weird_indices,original_atom_indices)
for weird_index, original_atom_index in zip(weird_indices,original_atom_indices):
virtual_atom_clouds[weird_index] = atom_clouds[original_atom_index,:]
if atom_triplets[original_atom_index,1] == weird_index:
virtual_atom_clouds[weird_index][0] +=1
else:
virtual_atom_clouds[weird_index][2] += 1
atom_clouds = np.concatenate([atom_clouds, np.array(virtual_atom_clouds)], axis=0)
return atom_clouds, atom_triplets
@@ -377,8 +381,8 @@ def _add_virtual_atoms(atom_clouds, atom_triplets, verbose=True):
if __name__ == '__main__':
import PDB_processing
import Bio.PDB
from preprocessing import PDBio,PDB_processing
PDB_folder = '/Users/jerometubiana/PDB/'
pdblist = Bio.PDB.PDBList()