mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
dev on atom protein representations
This commit is contained in:
@@ -5,6 +5,7 @@ from torch.utils import data
|
||||
# from chemical import NFRAMES
|
||||
from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params
|
||||
from kinematics import xyz_to_c6d, xyz_to_t2d
|
||||
from chemical import num2aa, aa2elt, aa2num
|
||||
from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss
|
||||
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz
|
||||
|
||||
@@ -250,6 +251,81 @@ class LossTestCase(unittest.TestCase):
|
||||
self.assertEqual(res_mask.shape[0], B)
|
||||
self.assertEqual(res_mask.shape[1], L)
|
||||
break
|
||||
|
||||
|
||||
class DataLoaderTestCase(unittest.TestCase):
    """Exercise "vaporizing" a four-residue window of a protein into atom nodes.

    The window around a random residue ``i`` (residues ``i-2 .. i+1``) is
    converted to per-atom features, and the same four residues are removed
    from every residue-indexed tensor yielded by the validation PDB loader.
    """

    # Number of residues removed around the chosen index (i-2, i-1, i, i+1).
    WINDOW = 4

    def setUp(self) -> None:
        super().setUp()
        self.loader_param = set_data_loader_params({})
        (
            pdb_items, fb_items, compl_items, neg_items, na_compl_items, na_neg_items, rna_items,
            sm_compl_items, valid_pdb, valid_homo, valid_compl, valid_neg, valid_na_compl,
            valid_na_neg, valid_rna, valid_sm_compl, homo
        ) = get_train_valid_set(self.loader_param)
        self.homo = homo

        valid_pdb_set = Dataset(
            list(valid_pdb.keys()),
            loader_pdb, valid_pdb,
            self.loader_param, homo, p_homo_cut=-1.0
        )
        self.valid_pdb_loader = data.DataLoader(valid_pdb_set)

    def _drop_window(self, tensor, i, dim):
        """Remove positions ``i-2 .. i+1`` of ``tensor`` along ``dim`` and check the new length."""
        head = [slice(None)] * tensor.dim()
        tail = list(head)
        head[dim] = slice(None, i - 2)
        tail[dim] = slice(i + 2, None)
        trimmed = torch.cat((tensor[tuple(head)], tensor[tuple(tail)]), dim=dim)
        self.assertEqual(trimmed.shape[dim], tensor.shape[dim] - self.WINDOW)
        return trimmed

    def test_vaporize_protein(self):
        for (seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb,
             xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames,
             bond_feats) in self.valid_pdb_loader:
            B, L = msa[:, 0, 0].shape
            if L < 6:
                # torch.randint(2, L-3) needs a non-empty range.
                self.skipTest("protein too short to atomize a four-residue window")

            # Use a plain int so slicing and index arithmetic stay scalar.
            i = int(torch.randint(2, L - 3, (1,)))

            # Residue types of the window, taken from the first sequence of
            # the first batch element.
            window_types = msa[0, 0, 0, i - 2:i + 2].tolist()

            lig_seq = []
            ra = []  # (residue index, atom index) of every atom that exists
            for offset, aa in enumerate(window_types):
                for atom_idx, element in enumerate(aa2elt[aa][:14]):
                    if element:
                        ra.append((i - 2 + offset, atom_idx))
                        lig_seq.append(aa2num[element])
            lig_seq = torch.tensor(lig_seq, dtype=torch.long)
            ra = torch.tensor(ra, dtype=torch.long).reshape(-1, 2)
            r, a = ra[:, 0], ra[:, 1]

            lig_xyz = true_crds[:, r, a]
            lig_mask = atom_mask[:, r, a]
            self.assertEqual(tuple(lig_xyz.shape[:2]), (B, len(lig_seq)))
            self.assertEqual(tuple(lig_mask.shape[:2]), (B, len(lig_seq)))

            # TODO: figure out xyz_t / t1d handling: set everything vaporized
            # to NaN, create new NaN features for the added length, and set
            # t1d to gaps.
            msa = self._drop_window(msa, i, 3)
            msa_masked = self._drop_window(msa_masked, i, 3)
            msa_full = self._drop_window(msa_full, i, 3)
            mask_msa = self._drop_window(mask_msa, i, 3)

            true_crds = self._drop_window(true_crds, i, 1)
            atom_mask = self._drop_window(atom_mask, i, 1)

            idx_pdb = self._drop_window(idx_pdb, i, 1)
            xyz_t = self._drop_window(xyz_t, i, 2)
            t1d = self._drop_window(t1d, i, 2)
            xyz_prev = self._drop_window(xyz_prev, i, 1)

            # Pairwise features lose the window along both residue axes.
            same_chain = self._drop_window(self._drop_window(same_chain, i, 1), i, 2)
            bond_feats = self._drop_window(self._drop_window(bond_feats, i, 1), i, 2)

            # Only the first batch is needed for this check.
            break
|
||||
|
||||
# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -858,7 +858,6 @@ def get_atom_frames(msa, mol, G):
|
||||
assert msa.shape[0] == len(selected_frames)
|
||||
return torch.tensor(selected_frames).long()
|
||||
|
||||
|
||||
### Generate bond features for small molecules ###
|
||||
def get_bond_feats(mol, G):
|
||||
"""creates 2d bond graph for small molecules"""
|
||||
@@ -878,3 +877,44 @@ def get_protein_bond_feats(protein_L):
|
||||
bond_feats[residues, residues+1] = 1
|
||||
bond_feats[residues+1, residues] = 1
|
||||
return bond_feats
|
||||
|
||||
def atomize_protein(i, msa, true_crds, atom_mask):
    """Turn the four-residue window around residue ``i`` into "atom" nodes.

    The window covers residues ``i-2``, ``i-1``, ``i`` and ``i+1``. For each
    of them the first 14 entries of ``aa2elt[residue_type]`` are scanned, and
    every truthy entry becomes one atom node whose token is
    ``aa2num[entry]``.

    Args:
        i: Index of the window's reference residue (int or one-element
            integer tensor). Must satisfy ``2 <= i <= L - 2``.
        msa: Tensor indexed as ``msa[0, 0, i-2:i+2]`` to read the window's
            residue types, so the residue axis must be the last of three.
            NOTE(review): assumed to be the unbatched per-example MSA; confirm
            against the caller.
        true_crds: Coordinates indexed as ``true_crds[residue, atom]``.
        atom_mask: Mask indexed as ``atom_mask[residue, atom]``.

    Returns:
        Tuple ``(lig_seq, ins, lig_xyz, lig_mask)``: atom tokens, an all-zero
        tensor shaped like ``lig_seq``, and the coordinates and mask entries
        of the selected atoms.

    Raises:
        ValueError: If the window does not fit inside the residue axis.
    """
    i = int(i)
    start, stop = i - 2, i + 2
    n_res = msa.shape[-1]
    if start < 0 or stop > n_res:
        # A negative start would silently wrap around when slicing.
        raise ValueError(
            f"cannot atomize residues {start}..{stop - 1}: index {i} is too close "
            f"to the ends of a length-{n_res} sequence"
        )

    # msa[0, 0, start:stop] already is the four-residue window; no further
    # indexing is needed (the extra [0] was a leftover from batched code).
    window_types = msa[0, 0, start:stop].tolist()

    lig_seq = []
    ra = []  # (residue index, atom index) of every atom that exists
    for offset, aa in enumerate(window_types):
        for atom_idx, element in enumerate(aa2elt[aa][:14]):
            if element:
                ra.append((start + offset, atom_idx))
                lig_seq.append(aa2num[element])

    lig_seq = torch.tensor(lig_seq, dtype=torch.long)
    ins = torch.zeros_like(lig_seq)
    # reshape keeps the (n_atoms, 2) layout even when no atom was selected.
    ra = torch.tensor(ra, dtype=torch.long).reshape(-1, 2)
    r, a = ra[:, 0], ra[:, 1]
    lig_xyz = true_crds[r, a]
    lig_mask = atom_mask[r, a]
    return lig_seq, ins, lig_xyz, lig_mask
|
||||
|
||||
def _drop_residue_window(tensor, start, stop, axes):
    """Return ``tensor`` with positions ``start:stop`` removed along every axis in ``axes``."""
    for axis in axes:
        head = [slice(None)] * tensor.dim()
        tail = list(head)
        head[axis] = slice(None, start)
        tail[axis] = slice(stop, None)
        tensor = torch.cat((tensor[tuple(head)], tensor[tuple(tail)]), dim=axis)
    return tensor


def remove_protein_info(i, seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats):
    """Remove the msa/template information for the "atomized" part of the protein.

    Residues ``i-2 .. i+1`` (four in total) are cut out of the residue axis
    of every input. Head slice, tail slice and concatenation axis are always
    the same axis for a given tensor.

    Residue axes, counted per example (without a batch axis):
        seq: 1; msa, msa_masked, msa_full, mask_msa: 2;
        true_crds, atom_mask, idx_pdb, xyz_prev: 0; xyz_t, t1d: 1;
        same_chain, bond_feats: 0 and 1.
    If ``bond_feats`` is 3-D, a single leading batch axis is assumed on every
    input and all axes above shift by one.
    NOTE(review): the per-example layout is inferred from how the tensors are
    sliced elsewhere in this change; confirm against the loaders.

    Args:
        i: Reference residue of the window (int or one-element integer
            tensor). Must satisfy ``2 <= i <= L - 2``.
        seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask,
        idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats: Tensors to trim.

    Returns:
        The thirteen tensors, in the order they were passed, with the window
        removed.

    Raises:
        ValueError: If the window does not fit inside the residue axis, or
            ``bond_feats`` is neither 2-D nor 3-D.
    """
    i = int(i)
    start, stop = i - 2, i + 2

    # bond_feats is a square residue-by-residue matrix, so its rank tells
    # whether a leading batch axis is present.
    batch_axes = bond_feats.dim() - 2
    if batch_axes not in (0, 1):
        raise ValueError(
            f"bond_feats must be (L, L) or (B, L, L), got shape {tuple(bond_feats.shape)}"
        )
    n_res = bond_feats.shape[-1]
    if start < 0 or stop > n_res:
        # A negative start would silently wrap around when slicing.
        raise ValueError(
            f"cannot remove residues {start}..{stop - 1}: index {i} is too close "
            f"to the ends of a length-{n_res} sequence"
        )

    def drop(tensor, *axes):
        return _drop_residue_window(tensor, start, stop, [batch_axes + ax for ax in axes])

    seq = drop(seq, 1)
    msa = drop(msa, 2)
    msa_masked = drop(msa_masked, 2)
    msa_full = drop(msa_full, 2)
    mask_msa = drop(mask_msa, 2)
    true_crds = drop(true_crds, 0)
    atom_mask = drop(atom_mask, 0)

    idx_pdb = drop(idx_pdb, 0)
    xyz_t = drop(xyz_t, 1)
    t1d = drop(t1d, 1)
    xyz_prev = drop(xyz_prev, 0)
    # Pairwise features lose the window along both residue axes.
    same_chain = drop(same_chain, 0, 1)
    bond_feats = drop(bond_feats, 0, 1)
    return seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, bond_feats
|
||||
|
||||
|
||||
Reference in New Issue
Block a user