dev on atom protein representations

This commit is contained in:
Rohith Krishna
2022-08-01 11:12:41 -07:00
parent f7679cb34c
commit 128d1b420e
2 changed files with 117 additions and 1 deletions

View File

@@ -5,6 +5,7 @@ from torch.utils import data
# from chemical import NFRAMES
from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params
from kinematics import xyz_to_c6d, xyz_to_t2d
from chemical import num2aa, aa2elt, aa2num
from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz
@@ -250,6 +251,81 @@ class LossTestCase(unittest.TestCase):
self.assertEqual(res_mask.shape[0], B)
self.assertEqual(res_mask.shape[1], L)
break
class DataLoaderTestCase(unittest.TestCase):
    """Checks for turning a short stretch of protein residues into atom nodes.

    ``setUp`` builds a ``DataLoader`` over the validation PDB set; the test
    then "vaporizes" a random 4-residue window of the first example and
    verifies that the residue axis of every per-residue feature shrinks
    accordingly.
    """

    # Number of leading atom slots per residue that are considered.
    N_ATOM_SLOTS = 14
    # Residues removed per window: i-2, i-1, i, i+1.
    WINDOW = 4

    def setUp(self) -> None:
        """Create the validation-PDB data loader used by the tests."""
        super().setUp()
        self.loader_param = set_data_loader_params({})
        (
            pdb_items, fb_items, compl_items, neg_items, na_compl_items, na_neg_items, rna_items,
            sm_compl_items, valid_pdb, valid_homo, valid_compl, valid_neg, valid_na_compl,
            valid_na_neg, valid_rna, valid_sm_compl, homo
        ) = get_train_valid_set(self.loader_param)
        self.homo = homo
        valid_pdb_set = Dataset(
            list(valid_pdb.keys()),
            loader_pdb, valid_pdb,
            self.loader_param, homo, p_homo_cut=-1.0
        )
        self.valid_pdb_loader = data.DataLoader(valid_pdb_set)

    @staticmethod
    def _drop_window(tensor, dim, start, stop):
        """Return ``tensor`` with indices ``start:stop`` cut out along ``dim``."""
        lead = (slice(None),) * dim
        head = tensor[lead + (slice(None, start),)]
        tail = tensor[lead + (slice(stop, None),)]
        return torch.cat((head, tail), dim=dim)

    def test_vaporize_protein(self):
        """Atomize a random 4-residue window and drop it from all features."""
        for (seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask,
             idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative,
             atom_frames, bond_feats) in self.valid_pdb_loader:
            B, L = msa[:, 0, 0].shape
            # Use a plain int so slice bounds and index tuples are not tensors.
            i = int(torch.randint(2, L - 3, (1,)))
            start, stop = i - 2, i + 2

            # Element symbols of the window's atoms, read from the first
            # batch entry of msa[:, 0, 0].
            window = msa[0, 0, 0, start:stop]
            residues_atomize = [aa2elt[num][:self.N_ATOM_SLOTS] for num in window]

            lig_seq = []
            ra = []  # (residue index, atom index) for every real atom
            for offset, elements in enumerate(residues_atomize):
                for atom_idx, element in enumerate(elements):
                    if element:
                        ra.append((start + offset, atom_idx))
                        lig_seq.append(aa2num[element])
            self.assertTrue(ra, "window contained no atoms to vaporize")

            lig_seq = torch.tensor(lig_seq)
            r, a = torch.tensor(ra).T
            lig_xyz = true_crds[:, r, a]
            lig_mask = atom_mask[:, r, a]
            n_atoms = len(lig_seq)
            self.assertEqual(tuple(lig_xyz.shape[:2]), (B, n_atoms))
            self.assertEqual(tuple(lig_mask.shape), (B, n_atoms))

            # TODO: figure out xyz_t / t1d handling - set everything
            # vaporized to NaN, create new NaN features for the added
            # length, and set t1d to gaps.
            drop = self._drop_window
            msa = drop(msa, 3, start, stop)
            msa_masked = drop(msa_masked, 3, start, stop)
            msa_full = drop(msa_full, 3, start, stop)
            mask_msa = drop(mask_msa, 3, start, stop)
            true_crds = drop(true_crds, 1, start, stop)
            atom_mask = drop(atom_mask, 1, start, stop)
            idx_pdb = drop(idx_pdb, 1, start, stop)
            xyz_t = drop(xyz_t, 2, start, stop)
            t1d = drop(t1d, 2, start, stop)
            xyz_prev = drop(xyz_prev, 1, start, stop)
            # Pairwise features lose the window on both residue axes.
            same_chain = drop(drop(same_chain, 1, start, stop), 2, start, stop)
            bond_feats = drop(drop(bond_feats, 1, start, stop), 2, start, stop)

            # These checks assume each sliced axis had length L on input,
            # which is what the slicing above relies on - TODO confirm.
            L_new = L - self.WINDOW
            for name, tensor, dim in (
                ("msa", msa, 3),
                ("msa_masked", msa_masked, 3),
                ("msa_full", msa_full, 3),
                ("mask_msa", mask_msa, 3),
                ("true_crds", true_crds, 1),
                ("atom_mask", atom_mask, 1),
                ("idx_pdb", idx_pdb, 1),
                ("xyz_t", xyz_t, 2),
                ("t1d", t1d, 2),
                ("xyz_prev", xyz_prev, 1),
                ("same_chain", same_chain, 1),
                ("same_chain", same_chain, 2),
                ("bond_feats", bond_feats, 1),
                ("bond_feats", bond_feats, 2),
            ):
                self.assertEqual(tensor.shape[dim], L_new, f"{name} dim {dim}")
            # One example is enough; loading the whole set is slow.
            break
# Run the whole unittest suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

View File

@@ -858,7 +858,6 @@ def get_atom_frames(msa, mol, G):
assert msa.shape[0] == len(selected_frames)
return torch.tensor(selected_frames).long()
### Generate bond features for small molecules ###
def get_bond_feats(mol, G):
"""creates 2d bond graph for small molecules"""
@@ -878,3 +877,44 @@ def get_protein_bond_feats(protein_L):
bond_feats[residues, residues+1] = 1
bond_feats[residues+1, residues] = 1
return bond_feats
def atomize_protein(i, msa, true_crds, atom_mask):
    """Turn the 4-residue window around residue ``i`` into "atom" nodes.

    The window covers residues ``i-2, i-1, i, i+1`` (the two preceding
    residues, residue ``i`` itself and the one following it). For every
    non-empty atom slot among the first 14 of each residue, one atom node
    is emitted.

    Args:
        i: anchor residue index; must be >= 2 so the window start is valid.
        msa: residue tokens, read as ``msa[0, 0, i-2:i+2]``
            (presumably ``(N_cycle, N_seq, L)`` - TODO confirm).
        true_crds: coordinates indexed as ``true_crds[residue, atom]``
            (presumably ``(L, n_atoms, 3)`` - TODO confirm).
        atom_mask: mask indexed as ``atom_mask[residue, atom]``.

    Returns:
        Tuple ``(lig_seq, ins, lig_xyz, lig_mask)``: the ``aa2num`` token of
        each atom's element, an all-zero tensor shaped like ``lig_seq``, and
        the coordinates and mask entries gathered for those atoms.

    Raises:
        ValueError: if ``i < 2`` (the window would start before residue 0).
    """
    n_atom_slots = 14
    i = int(i)
    start, stop = i - 2, i + 2
    if start < 0:
        raise ValueError(f"i must be >= 2 to atomize residues i-2..i+1, got {i}")

    # The slice is already a 1-D run of residue tokens; iterate it directly.
    window = msa[0, 0, start:stop]

    lig_seq = []
    ra = []  # (residue index, atom index) of every atom that exists
    for offset, aa in enumerate(window):
        for atom_idx, element in enumerate(aa2elt[int(aa)][:n_atom_slots]):
            if element:  # empty slots carry no element
                ra.append((start + offset, atom_idx))
                lig_seq.append(aa2num[element])

    lig_seq = torch.tensor(lig_seq, dtype=torch.long)
    ins = torch.zeros_like(lig_seq)
    # reshape keeps the (n, 2) layout even when no atoms were collected.
    r, a = torch.tensor(ra, dtype=torch.long).reshape(-1, 2).T
    lig_xyz = true_crds[r, a]
    lig_mask = atom_mask[r, a]
    return lig_seq, ins, lig_xyz, lig_mask
def _drop_residue_window(tensor, dim, start, stop):
    """Return ``tensor`` with indices ``start:stop`` removed along ``dim``."""
    lead = (slice(None),) * dim
    head = tensor[lead + (slice(None, start),)]
    tail = tensor[lead + (slice(stop, None),)]
    return torch.cat((head, tail), dim=dim)


def remove_protein_info(i, seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats):
    """Remove the features of the residues that were "atomized".

    Cuts residues ``i-2 .. i+1`` out of the residue axis of every input.
    Inputs are treated as unbatched, with the residue axis assumed to be
    (NOTE(review): inferred from the ``seq``/``msa`` slicing here and from
    ``atomize_protein`` indexing ``true_crds[residue, atom]`` - confirm
    against the loader):

        seq                                  axis 1
        msa, msa_masked, msa_full, mask_msa  axis 2
        true_crds, atom_mask, idx_pdb        axis 0
        xyz_t, t1d                           axis 1
        xyz_prev                             axis 0
        same_chain, bond_feats               axes 0 and 1 (pairwise)

    Args:
        i: anchor residue index; must be >= 2.
        seq, msa, ..., bond_feats: feature tensors to trim.

    Returns:
        The thirteen trimmed tensors, in the same order as the arguments
        (without ``i``).

    Raises:
        ValueError: if ``i < 2``; a negative slice start would silently
            remove the wrong residues.
    """
    i = int(i)
    start, stop = i - 2, i + 2
    if start < 0:
        raise ValueError(f"i must be >= 2 to remove residues i-2..i+1, got {i}")

    def drop(tensor, *dims):
        # Remove the window along each listed residue axis in turn.
        for dim in dims:
            tensor = _drop_residue_window(tensor, dim, start, stop)
        return tensor

    seq = drop(seq, 1)
    msa = drop(msa, 2)
    msa_masked = drop(msa_masked, 2)
    msa_full = drop(msa_full, 2)
    mask_msa = drop(mask_msa, 2)
    true_crds = drop(true_crds, 0)
    atom_mask = drop(atom_mask, 0)
    idx_pdb = drop(idx_pdb, 0)
    xyz_t = drop(xyz_t, 1)
    t1d = drop(t1d, 1)
    xyz_prev = drop(xyz_prev, 0)
    # Pairwise features must lose the window on both residue axes.
    same_chain = drop(same_chain, 0, 1)
    bond_feats = drop(bond_feats, 0, 1)
    return seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats