From 128d1b420e4535ffa4677ddf8662323e43726340 Mon Sep 17 00:00:00 2001
From: Rohith Krishna
Date: Mon, 1 Aug 2022 11:12:41 -0700
Subject: [PATCH] dev on atom protein representations

---
 RF2_allatom/tests.py | 76 ++++++++++++++++++++++++++++++++++++++++++++
 RF2_allatom/util.py  | 42 +++++++++++++++++++++++-
 2 files changed, 117 insertions(+), 1 deletion(-)

diff --git a/RF2_allatom/tests.py b/RF2_allatom/tests.py
index dd0cf3d..f615b61 100644
--- a/RF2_allatom/tests.py
+++ b/RF2_allatom/tests.py
@@ -5,6 +5,7 @@ from torch.utils import data
 # from chemical import NFRAMES
 from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params
 from kinematics import xyz_to_c6d, xyz_to_t2d
+from chemical import num2aa, aa2elt, aa2num
 from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss
 from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz
 
@@ -250,6 +251,81 @@ class LossTestCase(unittest.TestCase):
             self.assertEqual(res_mask.shape[0], B)
             self.assertEqual(res_mask.shape[1], L)
             break
+
+
+class DataLoaderTestCase(unittest.TestCase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.loader_param = set_data_loader_params({})
+        (
+            pdb_items, fb_items, compl_items, neg_items, na_compl_items, na_neg_items, rna_items,
+            sm_compl_items, valid_pdb, valid_homo, valid_compl, valid_neg, valid_na_compl,
+            valid_na_neg, valid_rna, valid_sm_compl, homo
+        ) = get_train_valid_set(self.loader_param)
+
+        pdb_IDs, pdb_weights, pdb_dict = pdb_items
+        na_compl_IDs, na_compl_weights, na_compl_dict = na_compl_items
+        rna_IDs, rna_weights, rna_dict = rna_items
+        sm_compl_IDs, sm_compl_weights, sm_compl_dict = sm_compl_items
+        self.homo = homo
+
+        valid_pdb_set = Dataset(
+            list(valid_pdb.keys()),
+            loader_pdb, valid_pdb,
+            self.loader_param, homo, p_homo_cut=-1.0
+        )
+        self.valid_pdb_loader = data.DataLoader(valid_pdb_set)
+
+    def test_vaporize_protein(self):
+        for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_pdb_loader:
+            B, L = msa[:, 0, 0].shape
+            i = torch.randint(2, L-3,(1,))
+            residues_atomize = msa[:, 0, 0, i-2:i+2]
+            residues_atomize = [aa2elt[num][:14] for num in residues_atomize[0]]
+            lig_seq = []
+            ra = []
+            for idx in range(len(residues_atomize)):
+                for jdx in range(14):
+                    if residues_atomize[idx][jdx]:
+                        ra.append((i-2+idx, jdx))
+                        lig_seq.append(aa2num[residues_atomize[idx][jdx]])
+            lig_seq = torch.tensor(lig_seq)
+            ins = torch.zeros_like(lig_seq)
+            print(lig_seq)
+            ra = torch.tensor(ra)
+            r,a = ra.T
+            lig_xyz = torch.zeros((len(ra), 3))
+            lig_xyz = true_crds[:, r, a]
+            lig_mask = atom_mask[:, r, a]
+            # print(lig_xyz)
+            # print(lig_mask)
+            #NEED TO FIGURE OUT XYZ_T, T1D set everything vaporized into NaN and then create new NaN features for length, set t1D into gaps
+            msa = torch.cat((msa[:, :, :, :i-2], msa[:, :, :, i+2:]), dim=3)
+            msa_masked = torch.cat((msa_masked[:, :, :, :i-2], msa_masked[:, :, :, i+2:]), dim=3)
+            msa_full = torch.cat((msa_full[:, :, :, :i-2], msa_full[:, :, :, i+2:]), dim=3)
+            mask_msa = torch.cat((mask_msa[:, :, :, :i-2], mask_msa[:, :, :, i+2:]), dim=3)
+
+            true_crds = torch.cat((true_crds[ :, :i-2], true_crds[ :, i+2:]), dim=1)
+            atom_mask = torch.cat((atom_mask[ :, :i-2], atom_mask[ :, i+2:]), dim=1)
+
+            idx_pdb = torch.cat((idx_pdb[ :, :i-2], idx_pdb[ :, i+2:]), dim=1)
+            print(idx_pdb.shape)
+            xyz_t = torch.cat((xyz_t[ :, :, :i-2], xyz_t[:, :, i+2:]), dim=2)
+            print(xyz_t.shape)
+            t1d = torch.cat((t1d[ :, :, :i-2], t1d[:, :, i+2:]), dim=2)
+            print(t1d.shape)
+            xyz_prev = torch.cat((xyz_prev[ :, :i-2], xyz_prev[:, i+2:]), dim=1)
+            print(xyz_prev.shape)
+            same_chain = torch.cat((same_chain[ :, :i-2], same_chain[:, i+2:]), dim=1)
+            same_chain = torch.cat((same_chain[ :, :, :i-2], same_chain[:, :, i+2:]), dim=2)
+            print(same_chain.shape)
+
+            print(bond_feats.shape)
+            bond_feats = torch.cat((bond_feats[ :, :i-2], bond_feats[:, i+2:]), dim=1)
+            bond_feats = torch.cat((bond_feats[ :, :, :i-2], bond_feats[:, :, i+2:]), dim=2)
+            print(bond_feats.shape)
+            break
 
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
diff --git a/RF2_allatom/util.py b/RF2_allatom/util.py
index 5c2be43..60ec83b 100644
--- a/RF2_allatom/util.py
+++ b/RF2_allatom/util.py
@@ -858,7 +858,6 @@ def get_atom_frames(msa, mol, G):
     assert msa.shape[0] == len(selected_frames)
     return torch.tensor(selected_frames).long()
 
-
 ### Generate bond features for small molecules ###
 def get_bond_feats(mol, G):
     """creates 2d bond graph for small molecules"""
@@ -878,3 +877,44 @@ def get_protein_bond_feats(protein_L):
     bond_feats[residues, residues+1] = 1
     bond_feats[residues+1, residues] = 1
     return bond_feats
+
+def atomize_protein(i, msa, true_crds, atom_mask):
+    """ given an index i, make residues i-2 .. i+1 (the 2 preceding residues, residue i and the following residue; 4 total) into "atom" nodes; returns (lig_seq, ins, lig_xyz, lig_mask) with one entry per atom kept """
+    residues_atomize = msa[0, 0, i-2:i+2]  # assumes msa[0, 0] is a 1-D row of residue tokens -- TODO confirm
+    residues_atomize = [aa2elt[num][:14] for num in residues_atomize]  # iterate the window directly: it is already 1-D, and a 0-d element cannot be iterated
+    lig_seq = []
+    ra = []
+    for idx in range(len(residues_atomize)):
+        for jdx in range(14):
+            if residues_atomize[idx][jdx]:
+                ra.append((i-2+idx, jdx))
+                lig_seq.append(aa2num[residues_atomize[idx][jdx]])
+    lig_seq = torch.tensor(lig_seq)
+    ins = torch.zeros_like(lig_seq)
+    ra = torch.tensor(ra)
+    r,a = ra.T
+    # gather the entries of every kept (residue, atom) pair; true_crds and atom_mask are indexed [residue, atom]
+    lig_xyz = true_crds[r, a]
+    lig_mask = atom_mask[r, a]
+    return lig_seq, ins, lig_xyz, lig_mask
+
+def remove_protein_info(i, seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats):
+    """ remove the msa/template information for the portion of the protein that was "atomized" (residues i-2 .. i+1); assumes per-example tensors with no batch axis, i.e. each is cut one axis earlier than its batched counterpart in tests.py -- TODO confirm against the caller """
+    seq = torch.cat((seq[:, :i-2], seq[:, i+2:]), dim=1)
+    msa = torch.cat((msa[:, :, :i-2], msa[:, :, i+2:]), dim=2)
+    msa_masked = torch.cat((msa_masked[:, :, :i-2], msa_masked[:, :, i+2:]), dim=2)
+    msa_full = torch.cat((msa_full[:, :, :i-2], msa_full[:, :, i+2:]), dim=2)
+    mask_msa = torch.cat((mask_msa[:, :, :i-2], mask_msa[:, :, i+2:]), dim=2)
+    true_crds = torch.cat((true_crds[:i-2], true_crds[i+2:]), dim=0)  # residue axis is 0, matching true_crds[r, a] in atomize_protein
+    atom_mask = torch.cat((atom_mask[:i-2], atom_mask[i+2:]), dim=0)
+
+    idx_pdb = torch.cat((idx_pdb[:i-2], idx_pdb[i+2:]), dim=0)
+    xyz_t = torch.cat((xyz_t[:, :i-2], xyz_t[:, i+2:]), dim=1)
+    t1d = torch.cat((t1d[:, :i-2], t1d[:, i+2:]), dim=1)
+    xyz_prev = torch.cat((xyz_prev[:i-2], xyz_prev[i+2:]), dim=0)
+    same_chain = torch.cat((same_chain[:i-2], same_chain[i+2:]), dim=0)  # pair features: drop the rows first ...
+    same_chain = torch.cat((same_chain[:, :i-2], same_chain[:, i+2:]), dim=1)  # ... then the matching columns
+    bond_feats = torch.cat((bond_feats[:i-2], bond_feats[i+2:]), dim=0)
+    bond_feats = torch.cat((bond_feats[:, :i-2], bond_feats[:, i+2:]), dim=1)
+    return seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats
+