mirror of
https://github.com/jertubiana/ScanNet.git
synced 2026-06-04 13:44:22 +08:00
Fixed (?) numba bug
This commit is contained in:
@@ -18,7 +18,7 @@ def make_PR_curves(
|
||||
title = '',
|
||||
figsize=(10, 10),
|
||||
margin=0.05,grid=0.1
|
||||
,fs=25):
|
||||
,fs=25,legend_fs=15):
|
||||
import matplotlib
|
||||
matplotlib.use('agg')
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -53,14 +53,14 @@ def make_PR_curves(
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
for i in range(nSubsets):
|
||||
ax.plot(all_PR_curves[i][1], all_PR_curves[i][0], color=subsetColors[i],linewidth=2.0,
|
||||
label='%s (AUCPR= %.3f)' % (subset_names[i], all_AUCPRs[i]))
|
||||
label='%s (%.3f)' % (subset_names[i], all_AUCPRs[i]))
|
||||
plt.xticks(np.arange(0, 1.0 + grid, grid), fontsize=fs * 2/3)
|
||||
plt.yticks(np.arange(0, 1.0 + grid, grid), fontsize=fs * 2/3)
|
||||
plt.xlim([0 - margin, 1 + margin])
|
||||
plt.ylim([0 - margin, 1 + margin])
|
||||
plt.grid()
|
||||
|
||||
plt.legend(fontsize=fs)
|
||||
plt.legend(fontsize=legend_fs)
|
||||
plt.xlabel('Recall', fontsize=fs)
|
||||
plt.ylabel('Precision', fontsize=fs)
|
||||
plt.title(title,fontsize=fs)
|
||||
|
||||
@@ -305,16 +305,20 @@ def _get_aa_frameCloud_quadruplet(atom_coordinates, atom_ids, verbose=True):
|
||||
|
||||
def add_virtual_atoms(atom_clouds, atom_triplets, verbose=True):
|
||||
virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose)
|
||||
if np.abs(virtual_atom_clouds).max() >1e8:
|
||||
print('The weird numba bug happened again at add_virtual_atoms, rerunning once')
|
||||
virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose)
|
||||
if np.abs(virtual_atom_clouds).max() > 1e8:
|
||||
print('The weird numba bug persists...')
|
||||
else:
|
||||
print('The weird numba bug was fixed by rerunning')
|
||||
|
||||
|
||||
if len(virtual_atom_clouds) > 0:
|
||||
virtual_atom_clouds = np.array(virtual_atom_clouds)
|
||||
if np.abs(virtual_atom_clouds).max() >1e8:
|
||||
print('The weird numba bug happened again at add_virtual_atoms, need to fix virtual atoms')
|
||||
weird_indices = np.nonzero(np.abs(virtual_atom_clouds).max(-1) >1e8 )[0]
|
||||
print('Fixing %s virtual atoms'%len(weird_indices))
|
||||
original_atom_indices = np.array([np.nonzero((atom_triplets[:,1:] == len(atom_triplets)+ index).max(-1))[0][0] for index in weird_indices])
|
||||
print(weird_indices,original_atom_indices)
|
||||
for weird_index, original_atom_index in zip(weird_indices,original_atom_indices):
|
||||
virtual_atom_clouds[weird_index] = atom_clouds[original_atom_index,:]
|
||||
if atom_triplets[original_atom_index,1] == weird_index:
|
||||
virtual_atom_clouds[weird_index][0] +=1
|
||||
else:
|
||||
virtual_atom_clouds[weird_index][2] += 1
|
||||
atom_clouds = np.concatenate([atom_clouds, np.array(virtual_atom_clouds)], axis=0)
|
||||
return atom_clouds, atom_triplets
|
||||
|
||||
@@ -377,8 +381,8 @@ def _add_virtual_atoms(atom_clouds, atom_triplets, verbose=True):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import PDB_processing
|
||||
import Bio.PDB
|
||||
from preprocessing import PDBio,PDB_processing
|
||||
|
||||
PDB_folder = '/Users/jerometubiana/PDB/'
|
||||
pdblist = Bio.PDB.PDBList()
|
||||
|
||||
Reference in New Issue
Block a user