mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 18:04:23 +08:00
147 lines
4.2 KiB
Python
147 lines
4.2 KiB
Python
# From Nick Polizzi
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
import prody as pr
|
|
import os
|
|
|
|
from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short
|
|
|
|
|
|
def get_dihedral_indices(resname, chi_num):
    """Return the atom indices for the specified dihedral angle.

    The atom names listed in ``chi[resname][chi_num]`` are each converted to
    their position within ``atom_order[resname]``.

    Returns an array of four NaNs when ``resname`` has no entry in ``chi`` or
    when that entry has no ``chi_num`` key.

    NOTE(review): a second function with this same name is defined further
    down in this module and rebinds it, so this version is only used by the
    module-level table construction that runs before that rebinding.
    """
    if resname in chi and chi_num in chi[resname]:
        ordered_names = atom_order[resname]
        return np.array([ordered_names.index(atom_name) for atom_name in chi[resname][chi_num]])
    # No such dihedral for this residue: return a NaN placeholder row.
    return np.full(4, np.nan)
|
|
|
|
|
|
# Lookup table: for every key of ``atom_order``, a stacked array with one row
# of atom indices per chi angle (chi1..chi4); rows are NaN where the chi is
# undefined for that residue.
# This must run before ``get_dihedral_indices`` is redefined later in the module.
dihedral_indices = defaultdict(list)
for aa in atom_order:
    dihedral_indices[aa] = np.array(
        [get_dihedral_indices(aa, chi_num) for chi_num in range(1, 5)]
    )
|
|
|
|
|
|
def vector_batch(a, b):
    """Return the difference ``a - b`` (element-wise for arrays)."""
    difference = a - b
    return difference
|
|
|
|
|
|
def unit_vector_batch(v):
    """Scale every row of ``v`` to unit Euclidean length.

    The norm is taken along axis 1 and kept as a column so it broadcasts
    against ``v``. Rows of zero length are not special-cased.
    """
    row_lengths = np.linalg.norm(v, axis=1, keepdims=True)
    return v / row_lengths
|
|
|
|
|
|
def dihedral_angle_batch(p):
    """Compute one dihedral angle per row of four points.

    ``p`` is indexed as ``p[row, point, xyz]``; points 0..3 of each row define
    the dihedral. The result is a 1-D array of angles in degrees, where
    negative angles are shifted by +360 so that values are non-negative.
    NaN coordinates propagate to a NaN angle.
    """
    first, second, third, fourth = p[:, 0], p[:, 1], p[:, 2], p[:, 3]

    # Consecutive difference vectors along the four-point chain.
    bond_a = first - second
    bond_b = second - third
    bond_c = third - fourth

    # Normals of the two planes sharing the central bond.
    normal_ab = np.cross(bond_a, bond_b)
    normal_bc = np.cross(bond_b, bond_c)

    # Vector orthogonal to both the first normal and the central bond axis.
    central_axis = bond_b / np.linalg.norm(bond_b, axis=1, keepdims=True)
    in_plane = np.cross(normal_ab, central_axis)

    # atan2 only needs the ratio, so the normals are left unnormalised.
    cos_term = (normal_ab * normal_bc).sum(axis=1)
    sin_term = (in_plane * normal_bc).sum(axis=1)

    angles = np.rad2deg(np.arctan2(sin_term, cos_term))
    return np.where(angles < 0, angles + 360, angles)
|
|
|
|
|
|
def batch_compute_dihedral_angles(sidechains):
    """Convert ``sidechains`` to a NumPy array and return its dihedral angles.

    Thin wrapper around ``dihedral_angle_batch``; the input is copied into a
    new array first, so array-likes such as nested lists are accepted.
    """
    return dihedral_angle_batch(np.array(sidechains))
|
|
|
|
|
|
def get_coords(prody_pdb):
    """Collect per-residue atom coordinates into an (N, 14, 3) array.

    Residues are the sorted unique residue indices of the CA atoms in
    ``prody_pdb``. For each residue the atom names come from ``atom_order``,
    keyed by the short residue code ('X' when the residue name is not in
    ``aa_long2short``). Column ``j`` holds the first coordinate set matching
    the j-th atom name, or NaN when the selection finds nothing. Columns
    beyond the residue's atom list keep their initial value of zero.
    """
    residue_ids = sorted(set(prody_pdb.ca.getResindices()))
    coords = np.zeros((len(residue_ids), 14, 3))
    for row, residue_id in enumerate(residue_ids):
        residue = prody_pdb.select(f'resindex {residue_id}')
        resname = residue.getResnames()[0]
        atom_names = atom_order[aa_long2short.get(resname, 'X')]
        for col, atom_name in enumerate(atom_names):
            atom = residue.select(f'name {atom_name}')
            if atom is None:
                # Atom absent from the structure: mark the slot as missing.
                coords[row, col, :] = np.nan
            else:
                coords[row, col, :] = atom.getCoords()[0]
    return coords
|
|
|
|
|
|
def get_onehot_sequence(seq):
    """One-hot encode ``seq`` as a (len(seq), 20) float array.

    Each element is mapped to a column through ``aa_short2aa_idx``; elements
    missing from that mapping fall back to the GLY column.
    """
    gly_column = 7  # 7 is the index for GLY
    onehot = np.zeros((len(seq), 20))
    for row, residue_code in enumerate(seq):
        column = aa_short2aa_idx.get(residue_code, gly_column)
        onehot[row, column] = 1
    return onehot
|
|
|
|
|
|
def get_dihedral_indices(onehot_sequence):
    """Stack the chi atom-index tables for every residue in a one-hot sequence.

    The column of each non-zero entry in ``onehot_sequence`` is translated to
    a short residue code via ``aa_idx2aa_short`` and used to look up the
    module-level ``dihedral_indices`` table.

    NOTE: this definition rebinds the name of the earlier two-argument
    ``get_dihedral_indices(resname, chi_num)`` in this module.
    """
    hot_columns = np.nonzero(onehot_sequence)[1]
    per_residue = []
    for aa_idx in hot_columns:
        per_residue.append(dihedral_indices[aa_idx2aa_short[aa_idx]])
    return np.array(per_residue)
|
|
|
|
|
|
def _get_chi_angles(coords, indices):
    """Gather the atoms of each chi dihedral and compute the angles.

    Parameters
    ----------
    coords : numpy array
        Atom coordinates indexed as ``coords[residue, atom, xyz]``.
    indices : numpy array
        Atom indices per residue and chi angle, indexed as
        ``indices[residue, chi, point]``; NaN marks an undefined dihedral.
        (assumes four chi angles of four points each, as the reshape below
        requires — TODO confirm against callers)

    Returns
    -------
    numpy array of shape (N, 4)
        Chi angles per residue as returned by ``batch_compute_dihedral_angles``;
        entries whose indices were NaN come out as NaN.
    """
    num_residues = coords.shape[0]
    missing = np.isnan(indices)
    # Substitute a valid dummy index *before* the integer cast: casting NaN to
    # int is undefined and triggers a RuntimeWarning on recent NumPy versions.
    atom_idx = np.where(missing, 0, indices).astype(int)
    residue_idx = np.arange(num_residues)[:, None, None]
    # Advanced indexing returns a copy, so ``coords`` itself is never modified.
    points = coords[residue_idx, atom_idx, :]
    # Blank out the dummy lookups so those dihedrals evaluate to NaN.
    points[missing] = np.nan
    chi_angles = batch_compute_dihedral_angles(points.reshape(-1, 4, 3))
    return chi_angles.reshape(num_residues, 4)
|
|
|
|
|
|
def get_chi_angles(coords, seq, return_onehot=False):
    """
    Compute the side-chain chi angles for a sequence of residues.

    Parameters
    ----------
    coords : numpy array
        Atom coordinates indexed as ``coords[residue, atom, xyz]``, with one
        row per element of ``seq``, e.g. the (N, 14, 3) array built by
        ``get_coords``.
    seq : sequence of str
        Short residue codes, one per residue. Codes that are not in
        ``aa_short2aa_idx`` are encoded as GLY by ``get_onehot_sequence``.
    return_onehot : bool, optional
        also return the one-hot encoding of ``seq``, by default False

    Returns
    -------
    numpy array of shape (N, 4)
        Array contains chi angles (in degrees) of sidechains in the row-order
        of ``seq``. If a chi angle is not defined for a residue, due to
        missing atoms or GLY / ALA, it is set to np.nan.
    tuple (chi_angles, onehot)
        Returned instead when ``return_onehot`` is True; ``onehot`` is the
        (N, 20) array from ``get_onehot_sequence``.
    """
    onehot = get_onehot_sequence(seq)
    # Local name chosen so it does not shadow the module-level
    # ``dihedral_indices`` lookup table.
    chi_atom_indices = get_dihedral_indices(onehot)
    chi_angles = _get_chi_angles(coords, chi_atom_indices)
    if return_onehot:
        return chi_angles, onehot
    return chi_angles
|
|
|
|
|
|
def test_get_chi_angles(print_chi_angles=False):
    """Smoke-test ``get_chi_angles`` on chain A of PDB entry 6w70.

    Needs an internet connection or '6w70.pdb' in the working directory.

    Parameters
    ----------
    print_chi_angles : bool, optional
        print the computed chi-angle array after the checks, by default False

    Returns
    -------
    bool
        True when all assertions pass (an AssertionError is raised otherwise).
    """
    try:
        pdb = pr.parsePDB('6w70')
        prody_pdb = pdb.select('chain A')

        # get_chi_angles takes coordinates and a sequence, not a ProDy object,
        # so derive both from the selection first.
        coords = get_coords(prody_pdb)
        # Build the sequence in the same residue order get_coords uses
        # (sorted unique residue indices of the CA atoms).
        ca_atoms = prody_pdb.ca
        resname_by_index = dict(zip(ca_atoms.getResindices(), ca_atoms.getResnames()))
        seq = [aa_long2short.get(resname_by_index[resindex], 'X')
               for resindex in sorted(resname_by_index)]

        chi_angles = get_chi_angles(coords, seq)
        assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
        assert 55.0 < chi_angles[0, 0] < 56.0
        print('test_get_chi_angles passed')
    finally:
        # Best-effort cleanup of '6w70.pdb.gz' (presumably the file fetched by
        # parsePDB); runs even when a check above fails.
        try:
            os.remove('6w70.pdb.gz')
        except OSError:
            pass
    if print_chi_angles:
        print(chi_angles)
    return True
|
|
|
|
|
|
# Running this module as a script executes the chi-angle smoke test and
# prints the resulting angles.
if __name__ == '__main__':
    test_get_chi_angles(print_chi_angles=True)
|
|
|
|
|