from preprocessing.protein_chemistry import dictionary_covalent_bonds_numba, atom_type_mass, list_atoms, aa_to_index
from numba import njit, types
from numba.typed import List, Dict
import numpy as np


def get_atom_frameCloud(sequence, atom_coordinates, atom_ids):
    """Build the atom-level point cloud and its covalent-bond triplets.

    Parameters
    ----------
    sequence : indexable of residue names; ``sequence[l]`` is concatenated
        with ``'_<atom id>'`` to form keys of
        ``dictionary_covalent_bonds_numba``.
    atom_coordinates : per-residue sequence of coordinate arrays
        (concatenated along axis 0).
    atom_ids : per-residue sequence of integer atom ids, aligned with
        ``atom_coordinates``.

    Returns
    -------
    atom_clouds : all atom coordinates stacked along axis 0.
    atom_triplets : int32 array with one row ``(center, previous, next)`` per
        atom; a missing partner is encoded as ``-1``.
    atom_attributes : all atom ids concatenated.
    atom_indices : int32 column vector giving, for every atom, the index of
        the residue it belongs to.
    """
    atom_clouds = np.concatenate(atom_coordinates, axis=0)
    atom_attributes = np.concatenate(atom_ids, axis=-1)
    atom_triplets = np.array(
        _get_atom_triplets(sequence, List(atom_ids), dictionary_covalent_bonds_numba),
        dtype=np.int32)
    atom_indices = np.concatenate(
        [np.full(len(atom_ids[l]), l, dtype=np.int32) for l in range(len(sequence))],
        axis=-1)[:, np.newaxis]
    return atom_clouds, atom_triplets, atom_attributes, atom_indices


@njit(parallel=False, cache=False)
def _get_atom_triplets(sequence, atom_ids, dictionary_covalent_bonds_numba):
    """Return, for every atom, the tuple ``(center, previous, next)``.

    Indices are global atom indices (residues laid out one after another).
    ``-1`` marks a partner that is absent.  Atom id 17 (N) and atom id 0 (C)
    are special-cased because their partners sit in the neighbouring residue;
    id 1 is the C-alpha.  Every other atom is looked up in
    ``dictionary_covalent_bonds_numba`` under the key ``'<aa>_<id>'``.
    """
    L = len(sequence)
    atom_triplets = List()
    all_keys = List(dictionary_covalent_bonds_numba.keys())
    current_natoms = 0  # Global index of the first atom of the current residue.
    for l in range(L):
        aa = sequence[l]
        atom_id = atom_ids[l]
        natoms = len(atom_id)
        for n in range(natoms):
            atom = atom_id[n]
            if atom == 17:  # N, special case, bound to C of previous aa.
                if l > 0:
                    if 0 in atom_ids[l - 1]:
                        previous = current_natoms - len(atom_ids[l - 1]) + atom_ids[l - 1].index(0)
                    else:
                        previous = -1
                else:
                    previous = -1
                if 1 in atom_id:
                    next_ = current_natoms + atom_id.index(1)
                else:
                    next_ = -1
            elif atom == 0:  # C, special case, bound to N of next aa.
                if 1 in atom_id:
                    previous = current_natoms + atom_id.index(1)
                else:
                    previous = -1
                if l < L - 1:
                    if 17 in atom_ids[l + 1]:
                        next_ = current_natoms + natoms + atom_ids[l + 1].index(17)
                    else:
                        next_ = -1
                else:
                    next_ = -1
            else:
                key = aa + '_' + str(atom)
                if key in all_keys:
                    previous_id, next_id, _ = dictionary_covalent_bonds_numba[key]
                else:
                    print('Strange atom', key)
                    previous_id = -1
                    next_id = -1
                if previous_id in atom_id:
                    previous = current_natoms + atom_id.index(previous_id)
                else:
                    previous = -1
                if next_id in atom_id:
                    next_ = current_natoms + atom_id.index(next_id)
                else:
                    next_ = -1
            atom_triplets.append((current_natoms + n, previous, next_))
        current_natoms += natoms
    return atom_triplets


def get_aa_frameCloud(atom_coordinates, atom_ids, verbose=True, method='triplet_backbone'):
    """Build the residue-level point cloud and frame index tuples.

    Parameters
    ----------
    atom_coordinates, atom_ids : per-residue coordinates and atom ids.
    verbose : forwarded to the builder; enables its warning prints.
    method : one of ``'triplet_backbone'``, ``'triplet_sidechain'``,
        ``'triplet_cbeta'`` or ``'quadruplet'``.

    Returns
    -------
    aa_clouds : array of points (C-alphas plus the virtual / side-chain
        points added by the selected builder).
    aa_triplets : int32 array of index tuples into ``aa_clouds`` (three
        columns, four for ``'quadruplet'``).
    aa_indices : int32 column vector ``0 .. n_residues - 1``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'triplet_backbone':
        get_aa_frameCloud_ = _get_aa_frameCloud_triplet_backbone
    elif method == 'triplet_sidechain':
        get_aa_frameCloud_ = _get_aa_frameCloud_triplet_sidechain
    elif method == 'triplet_cbeta':
        get_aa_frameCloud_ = _get_aa_frameCloud_triplet_cbeta
    elif method == 'quadruplet':
        get_aa_frameCloud_ = _get_aa_frameCloud_quadruplet
    else:
        # Previously fell through and crashed later with an unbound local.
        raise ValueError(
            "Unknown method %r; expected 'triplet_backbone', 'triplet_sidechain', "
            "'triplet_cbeta' or 'quadruplet'" % (method,))

    aa_clouds, aa_triplets = get_aa_frameCloud_(List(atom_coordinates), List(atom_ids), verbose=verbose)
    aa_indices = np.arange(len(atom_coordinates)).astype(np.int32)[:, np.newaxis]
    aa_clouds = np.array(aa_clouds)
    aa_triplets = np.array(aa_triplets, dtype=np.int32)
    return aa_clouds, aa_triplets, aa_indices


@njit(cache=True, parallel=False)
def _get_aa_frameCloud_triplet_backbone(atom_coordinates, atom_ids, verbose=True):
    """Frames built from consecutive C-alphas: (CA_l, CA_{l-1}, CA_{l+1}).

    The cloud holds the ``L`` C-alphas followed by two virtual C-alphas
    (indices ``L`` and ``L + 1``) that serve as the missing neighbour of the
    first and last residue.  When a residue has no C-alpha (atom id 1) its
    first atom is used instead.

    NOTE: the virtual points read entries 0, 1, 2 and L-1, L-2, L-3 of the
    cloud, so the chain is expected to contain at least three residues.
    """
    L = len(atom_coordinates)
    aa_clouds = List()
    aa_triplets = List()
    for l in range(L):
        atom_coordinate = atom_coordinates[l]
        atom_id = atom_ids[l]
        if 1 in atom_id:
            calpha_coordinate = atom_coordinate[atom_id.index(1)]
        else:
            if verbose:
                print('Warning, pathological amino acid missing calpha', l)
            calpha_coordinate = atom_coordinate[0]
        aa_clouds.append(calpha_coordinate)

    # Add virtual calpha at beginning and at the end.
    aa_clouds.append(aa_clouds[0] + (aa_clouds[1] - aa_clouds[2]))
    aa_clouds.append(aa_clouds[L - 1] + (aa_clouds[L - 2] - aa_clouds[L - 3]))

    for l in range(L):
        center = l
        if l == 0:
            previous = L  # Virtual C-alpha placed before the first residue.
        else:
            previous = l - 1
        if l == L - 1:
            next_ = L + 1  # Virtual C-alpha placed after the last residue.
        else:
            next_ = l + 1
        aa_triplets.append((center, previous, next_))
    return aa_clouds, aa_triplets


@njit(cache=True, parallel=False)
def _get_aa_frameCloud_triplet_sidechain(atom_coordinates, atom_ids, verbose=True):
    """Frames built as (CA_l, CA_{l-1}, side-chain centre of mass of l).

    Points are appended residue by residue: the C-alpha, then (first residue
    only) a virtual previous C-alpha, then the side-chain centre of mass.
    Atoms with ids 0, 1, 17, 26 and 34 are excluded from the centre of mass
    (0, 1 and 17 are C, C-alpha and N; 26 and 34 are presumably the remaining
    backbone atoms -- confirm against ``list_atoms``).  Masses come from
    ``atom_type_mass``.

    NOTE: the virtual C-alpha of the first residue reads
    ``atom_coordinates[1][0]``, so at least two residues are expected.
    """
    L = len(atom_coordinates)
    aa_clouds = List()
    aa_triplets = List()
    count = 0  # Number of points appended to aa_clouds so far.
    for l in range(L):
        atom_coordinate = atom_coordinates[l]
        atom_id = atom_ids[l]
        natoms = len(atom_id)
        if 1 in atom_id:
            calpha_coordinate = atom_coordinate[atom_id.index(1)]
        else:
            if verbose:
                print('Warning, pathological amino acid missing calpha', l)
            calpha_coordinate = atom_coordinate[0]

        center = count
        aa_clouds.append(calpha_coordinate)
        count += 1
        if count > 1:
            # Center of the previous triplet, i.e. the previous C-alpha.
            previous = aa_triplets[-1][0]
        else:
            # Need to place another virtual Calpha.
            virtual_calpha_coordinate = 2 * calpha_coordinate - atom_coordinates[1][0]
            aa_clouds.append(virtual_calpha_coordinate)
            previous = count
            count += 1

        sidechain_CoM = np.zeros(3, dtype=np.float32)
        sidechain_mass = 0.
        for n in range(natoms):
            if not atom_id[n] in [0, 1, 17, 26, 34]:
                mass = atom_type_mass[atom_id[n]]
                sidechain_CoM += mass * atom_coordinate[n]
                sidechain_mass += mass
        if sidechain_mass > 0:
            sidechain_CoM /= sidechain_mass
        else:  # Usually case of Glycin
            # TO CHANGE FOR NEXT NETWORK ITERATION... I used the wrong nitrogen when I rewrote the function...
            # (the N of the *previous* residue is used below; behaviour is kept on purpose.)
            if l > 0:
                if (0 in atom_id) & (1 in atom_id) & (17 in atom_ids[l - 1]):
                    # If C,N,Calpha are here, place virtual CoM
                    sidechain_CoM = 3 * atom_coordinate[atom_id.index(1)] \
                        - atom_coordinates[l - 1][atom_ids[l - 1].index(17)] \
                        - atom_coordinate[atom_id.index(0)]
                else:
                    if verbose:
                        print('Warning, pathological amino acid missing side chain and backbone', l)
                    sidechain_CoM = atom_coordinate[-1]
            else:
                if verbose:
                    print('Warning, pathological amino acid missing side chain and backbone', l)
                sidechain_CoM = atom_coordinate[-1]
            # Intended version, using the N of the current residue:
            # if (0 in atom_id) & (1 in atom_id) & (17 in atom_id):  # If C,N,Calpha are here, place virtual CoM
            #     sidechain_CoM = 3 * atom_coordinate[atom_id.index(1)] - atom_coordinate[atom_id.index(17)] - \
            #                     atom_coordinate[atom_id.index(0)]
            # else:
            #     if verbose:
            #         print('Warning, pathological amino acid missing side chain and backbone', l)
            #     sidechain_CoM = atom_coordinate[-1]

        aa_clouds.append(sidechain_CoM)
        next_ = count
        count += 1
        aa_triplets.append((center, previous, next_))
    return aa_clouds, aa_triplets


@njit(cache=True, parallel=False)
def _get_aa_frameCloud_triplet_cbeta(atom_coordinates, atom_ids, verbose=True):
    """Frames built as (CA_l, CA_{l-1}, CB_l).

    Points are appended residue by residue: the C-alpha, then (first residue
    only) a virtual previous C-alpha, then the C-beta (atom id 2).  When the
    C-beta is absent a virtual one is placed at ``3 * CA - N - C`` if C, CA
    and N are all present; otherwise the last atom of the residue is used.

    NOTE: the virtual C-alpha of the first residue reads
    ``atom_coordinates[1][0]``, so at least two residues are expected.
    """
    L = len(atom_coordinates)
    aa_clouds = List()
    aa_triplets = List()
    count = 0  # Number of points appended to aa_clouds so far.
    for l in range(L):
        atom_coordinate = atom_coordinates[l]
        atom_id = atom_ids[l]
        if 1 in atom_id:
            calpha_coordinate = atom_coordinate[atom_id.index(1)]
        else:
            if verbose:
                print('Warning, pathological amino acid missing calpha', l)
            calpha_coordinate = atom_coordinate[0]

        if 2 in atom_id:
            cbeta_coordinate = atom_coordinate[atom_id.index(2)]
        else:
            if (0 in atom_id) & (1 in atom_id) & (17 in atom_id):
                # If C,N,Calpha are here, place virtual CoM
                cbeta_coordinate = 3 * atom_coordinate[atom_id.index(1)] \
                    - atom_coordinate[atom_id.index(17)] \
                    - atom_coordinate[atom_id.index(0)]
            else:
                if verbose:
                    print('Warning, pathological amino acid missing cbeta and backbone', l)
                cbeta_coordinate = atom_coordinate[-1]

        center = count
        aa_clouds.append(calpha_coordinate)
        count += 1
        if count > 1:
            # Center of the previous triplet, i.e. the previous C-alpha.
            previous = aa_triplets[-1][0]
        else:
            # Need to place another virtual Calpha.
            virtual_calpha_coordinate = 2 * calpha_coordinate - atom_coordinates[1][0]
            aa_clouds.append(virtual_calpha_coordinate)
            previous = count
            count += 1

        aa_clouds.append(cbeta_coordinate)
        next_ = count
        count += 1
        aa_triplets.append((center, previous, next_))
    return aa_clouds, aa_triplets


@njit(cache=True, parallel=False)
def _get_aa_frameCloud_quadruplet(atom_coordinates, atom_ids, verbose=True):
    """Frames built as (CA_l, CA_{l-1}, CA_{l+1}, side-chain centre of mass).

    Cloud layout: ``L`` C-alphas, two virtual C-alphas (indices ``L`` and
    ``L + 1``), then one side-chain centre of mass per residue (index
    ``L + 2 + l``).  Despite the ``aa_triplets`` name, each returned tuple
    has four entries.

    NOTE: the virtual C-alphas read entries 0, 1, 2 and L-1, L-2, L-3 of the
    cloud, so the chain is expected to contain at least three residues.
    """
    L = len(atom_coordinates)
    aa_clouds = List()
    aa_triplets = List()
    for l in range(L):
        atom_coordinate = atom_coordinates[l]
        atom_id = atom_ids[l]
        if 1 in atom_id:
            calpha_coordinate = atom_coordinate[atom_id.index(1)]
        else:
            if verbose:
                print('Warning, pathological amino acid missing calpha', l)
            calpha_coordinate = atom_coordinate[0]
        aa_clouds.append(calpha_coordinate)

    # Add virtual calpha at beginning and at the end.
    aa_clouds.append(aa_clouds[0] + (aa_clouds[1] - aa_clouds[2]))
    aa_clouds.append(aa_clouds[L - 1] + (aa_clouds[L - 2] - aa_clouds[L - 3]))

    for l in range(L):
        atom_coordinate = atom_coordinates[l]
        atom_id = atom_ids[l]
        natoms = len(atom_id)
        sidechain_CoM = np.zeros(3, dtype=np.float32)
        sidechain_mass = 0.
        for n in range(natoms):
            if not atom_id[n] in [0, 1, 17, 26, 34]:
                mass = atom_type_mass[atom_id[n]]
                sidechain_CoM += mass * atom_coordinate[n]
                sidechain_mass += mass
        if sidechain_mass > 0:
            sidechain_CoM /= sidechain_mass
        else:  # Usually case of Glycin
            if (0 in atom_id) & (1 in atom_id) & (17 in atom_id):
                # If C,N,Calpha are here, place virtual CoM
                sidechain_CoM = 3 * atom_coordinate[atom_id.index(1)] \
                    - atom_coordinate[atom_id.index(17)] \
                    - atom_coordinate[atom_id.index(0)]
            else:
                if verbose:
                    print('Warning, pathological amino acid missing side chain and backbone', l)
                sidechain_CoM = atom_coordinate[-1]
        aa_clouds.append(sidechain_CoM)

        center = l
        if l == 0:
            previous = L  # Virtual C-alpha placed before the first residue.
        else:
            previous = l - 1
        if l == L - 1:
            next_ = L + 1  # Virtual C-alpha placed after the last residue.
        else:
            next_ = l + 1
        dipole = L + 2 + l
        aa_triplets.append((center, previous, next_, dipole))
    return aa_clouds, aa_triplets


def add_virtual_atoms(atom_clouds, atom_triplets, verbose=True):
    """Replace missing bond partners (``-1``) by virtual atoms.

    Virtual atoms are appended after the real ones, so virtual atom ``k`` has
    global index ``len(atom_triplets) + k``.

    Parameters
    ----------
    atom_clouds : array of atom coordinates, one row per atom.
    atom_triplets : integer array of ``(center, previous, next)`` rows.
        NOTE: it is updated in place by ``_add_virtual_atoms``; the returned
        array is the same object.
    verbose : enables the warning prints of ``_add_virtual_atoms``.

    Returns
    -------
    atom_clouds : input coordinates followed by the virtual atoms (the input
        itself when no virtual atom was needed).
    atom_triplets : triplets with the ``-1`` entries filled in.
    """
    virtual_atom_clouds, atom_triplets = _add_virtual_atoms(atom_clouds, atom_triplets, verbose=verbose)
    if len(virtual_atom_clouds) > 0:
        virtual_atom_clouds = np.array(virtual_atom_clouds)
        if np.abs(virtual_atom_clouds).max() > 1e8:
            # Some virtual atoms came back with absurd coordinates: rebuild
            # them as "owning atom + unit offset".
            print('The weird numba bug happened again at add_virtual_atoms, need to fix virtual atoms')
            natoms = len(atom_triplets)
            weird_indices = np.nonzero(np.abs(virtual_atom_clouds).max(-1) > 1e8)[0]
            print('Fixing %s virtual atoms' % len(weird_indices))
            # Row of the real atom whose previous/next slot points to each broken virtual atom.
            original_atom_indices = np.array([
                np.nonzero((atom_triplets[:, 1:] == natoms + index).max(-1))[0][0]
                for index in weird_indices])
            print(weird_indices, original_atom_indices)
            for weird_index, original_atom_index in zip(weird_indices, original_atom_indices):
                virtual_atom_clouds[weird_index] = atom_clouds[original_atom_index, :]
                # Triplets store the global index natoms + k of virtual atom k,
                # so the comparison must use that offset (it used the bare
                # position before and practically never matched).
                if atom_triplets[original_atom_index, 1] == natoms + weird_index:
                    virtual_atom_clouds[weird_index][0] += 1  # virtual "previous": +x
                else:
                    virtual_atom_clouds[weird_index][2] += 1  # virtual "next": +z
        atom_clouds = np.concatenate([atom_clouds, np.array(virtual_atom_clouds)], axis=0)
    return atom_clouds, atom_triplets


@njit(cache=False)
def _add_virtual_atoms(atom_clouds, atom_triplets, verbose=True):
    """Create virtual partner atoms for every triplet with a ``-1`` entry.

    Returns the list of virtual atom coordinates and ``atom_triplets``, whose
    ``-1`` entries have been overwritten in place with the global indices
    ``natoms + k`` of the newly created virtual atoms.
    """
    natoms = len(atom_triplets)
    virtual_atom_clouds = List()
    count_virtual_atoms = 0
    centers = list(atom_triplets[:, 0])
    for n in range(natoms):
        triplet = atom_triplets[n]  # Writes below go through to atom_triplets.
        case1 = (triplet[1] >= 0) & (triplet[2] >= 0)
        case2 = (triplet[1] < 0) & (triplet[2] >= 0)
        case3 = (triplet[1] >= 0) & (triplet[2] < 0)
        case4 = (triplet[1] < 0) & (triplet[2] < 0)
        if case1:  # Atom has at least two covalent bonds.
            continue
        elif case2:
            # Atom has one covalent bond. Previous is missing, next is present
            # (Either N-terminal N or missing atom).
            next_triplet = atom_triplets[centers.index(triplet[2])]
            if next_triplet[2] >= 0:
                # Next of next is present. Build virtual atom to obtain parallelogram.
                virtual_atom = atom_clouds[next_triplet[0]] - atom_clouds[next_triplet[2]] + atom_clouds[triplet[0]]
            else:
                # Next of next is also absent. Pathological case, use absolute x direction...
                if verbose:
                    print('Pathological case, atom has only one bond and its next partner too', triplet[0],
                          triplet[2])
                virtual_atom = atom_clouds[triplet[0]] + np.array([1, 0, 0])
            virtual_atom_clouds.append(virtual_atom)
            triplet[1] = natoms + count_virtual_atoms
            count_virtual_atoms += 1
        elif case3:
            # Atom has one covalent bond. Next is missing, previous is present
            # (Either C-terminal C or missing atom).
            previous_triplet = atom_triplets[centers.index(triplet[1])]
            if previous_triplet[1] >= 0:
                # Previous of previous is present. Build virtual atom to obtain parallelogram.
                virtual_atom = atom_clouds[previous_triplet[0]] - atom_clouds[previous_triplet[1]] + atom_clouds[
                    triplet[0]]
            else:
                # Previous of previous is also absent. Pathological case, use absolute z direction...
                if verbose:
                    print('Pathological case, atom has only one bond and its previous partner too', triplet[0],
                          triplet[1])
                virtual_atom = atom_clouds[triplet[0]] + np.array([0, 0, 1])
            virtual_atom_clouds.append(virtual_atom)
            triplet[2] = natoms + count_virtual_atoms
            count_virtual_atoms += 1
        elif case4:
            # Atom has no covalent bonds. Should never happen, use absolute coordinates.
            if verbose:
                print('Pathological case, atom has no bonds at all', triplet[0])
            virtual_previous_atom = atom_clouds[triplet[0]] + np.array([1, 0, 0])
            virtual_next_atom = atom_clouds[triplet[0]] + np.array([0, 0, 1])
            virtual_atom_clouds.append(virtual_previous_atom)
            virtual_atom_clouds.append(virtual_next_atom)
            triplet[1] = natoms + count_virtual_atoms
            triplet[2] = natoms + count_virtual_atoms + 1
            count_virtual_atoms += 2
    return virtual_atom_clouds, atom_triplets


if __name__ == '__main__':
    # Demo: download one structure, build the atom- and residue-level frame
    # clouds and print the shape/dtype of every network input.
    import Bio.PDB
    from preprocessing import PDBio, PDB_processing

    PDB_folder = '/Users/jerometubiana/PDB/'
    pdblist = Bio.PDB.PDBList()
    # pdb = '1a3x'
    pdb = '2kho'
    chain = 'A'
    name = pdblist.retrieve_pdb_file(pdb, pdir=PDB_folder)
    struct, chains = PDBio.load_chains(pdb_id=pdb, chain_ids=[(0, chain)], file=PDB_folder + '%s.cif' % pdb)
    sequence, backbone_coordinates, atom_coordinates, atom_ids, atom_types = PDB_processing.process_chain(chains)

    atom_clouds, atom_triplets, atom_attributes, atom_indices = get_atom_frameCloud(sequence, atom_coordinates,
                                                                                    atom_ids)
    # Print the first triplets by atom name as a sanity check.
    for i in range(20):
        tmp = atom_triplets[i, :]
        center = list_atoms[atom_attributes[tmp[0]]]
        if tmp[1] >= 0:
            previous = list_atoms[atom_attributes[tmp[1]]]
        else:
            previous = 'NONE'
        if tmp[2] >= 0:
            following = list_atoms[atom_attributes[tmp[2]]]
        else:
            following = 'NONE'
        print(i, center, previous, following)

    atom_clouds_filled, atom_triplets_filled = add_virtual_atoms(atom_clouds, atom_triplets, verbose=True)

    aa_clouds, aa_triplets, aa_indices = get_aa_frameCloud(atom_coordinates, atom_ids, verbose=True)
    aa_attributes = np.array([aa_to_index[aa] for aa in sequence], dtype=np.int32)

    inputs2network = [
        aa_triplets,
        aa_indices,
        aa_clouds,
        aa_attributes,
        atom_triplets_filled,
        atom_indices,
        atom_clouds_filled,
    ]
    for array in inputs2network:
        print(array.shape, array.dtype)