figuring out symmetric configurations

This commit is contained in:
Rohith Krishna
2022-08-03 16:44:44 -07:00
parent 128d1b420e
commit 8e8e1a4eae
2 changed files with 28 additions and 18 deletions

View File

@@ -7,7 +7,7 @@ from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetR
from kinematics import xyz_to_c6d, xyz_to_t2d
from chemical import num2aa, aa2elt, aa2num
from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz, long2alt
class LossTestCase(unittest.TestCase):
@@ -244,6 +244,7 @@ class LossTestCase(unittest.TestCase):
break
def test_res_mask(self):
"""updated res_mask to not mask "atom" nodes that only have one backbone atom filled in """
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,0,0])))
@@ -252,6 +253,12 @@ class LossTestCase(unittest.TestCase):
self.assertEqual(res_mask.shape[1], L)
break
def test_resolve_equiv_natives(self):
""" test that resolve_equiv_natives works"""
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
true_crds, atom_mask = resolve_equiv_natives(torch.randn(true_crds[0, 0].unsqueeze(0).shape), true_crds, atom_mask)
print(true_crds)
break
class DataLoaderTestCase(unittest.TestCase):
@@ -283,6 +290,13 @@ class DataLoaderTestCase(unittest.TestCase):
i = torch.randint(2, L-3,(1,))
residues_atomize = msa[:, 0, 0, i-2:i+2]
residues_atomize = [aa2elt[num][:14] for num in residues_atomize[0]]
true_alt = torch.zeros_like(true_crds)
true_alt.scatter_(2, long2alt[msa[:, 0, 0],:,None].repeat(1,1,1,3), true_crds)
print(true_crds.shape)
print((true_crds[:, i-2:i+2] == true_alt[:, i-2:i+2]).all(dim=2).all(dim=2).squeeze())
print(torch.nonzero(~(true_crds[:, i-2:i+2] == true_alt[:, i-2:i+2]).all(dim=2).all(dim=2).squeeze()).squeeze())
lig_seq = []
ra = []
for idx in range(len(residues_atomize)):
@@ -310,21 +324,16 @@ class DataLoaderTestCase(unittest.TestCase):
atom_mask = torch.cat((atom_mask[ :, :i-2], atom_mask[ :, i+2:]), dim=1)
idx_pdb = torch.cat((idx_pdb[ :, :i-2], idx_pdb[ :, i+2:]), dim=1)
print(idx_pdb.shape)
xyz_t = torch.cat((xyz_t[ :, :, :i-2], xyz_t[:, :, i+2:]), dim=2)
print(xyz_t.shape)
t1d = torch.cat((t1d[ :, :, :i-2], t1d[:, :, i+2:]), dim=2)
print(t1d.shape)
xyz_prev = torch.cat((xyz_prev[ :, :i-2], xyz_prev[:, i+2:]), dim=1)
print(xyz_prev.shape)
same_chain = torch.cat((same_chain[ :, :i-2], same_chain[:, i+2:]), dim=1)
same_chain = torch.cat((same_chain[ :, :, :i-2], same_chain[:, :, i+2:]), dim=2)
print(same_chain.shape)
print(bond_feats.shape)
bond_feats = torch.cat((bond_feats[ :, :i-2], bond_feats[:, i+2:]), dim=1)
bond_feats = torch.cat((bond_feats[ :, :, :i-2], bond_feats[:, :, i+2:]), dim=2)
print(bond_feats.shape)
break
if __name__ == '__main__':

View File

@@ -22,7 +22,7 @@ from scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedu
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
#torch.autograd.set_detect_anomaly(True)
torch.autograd.set_detect_anomaly(True)
torch.manual_seed(5924)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
@@ -175,7 +175,7 @@ class Trainer():
loss_s = list()
tot_loss = 0.0
# c6d loss
for i in range(4):
loss = self.loss_fn(logit_s[i], label_s[...,i]) # (B, L, L)
@@ -243,6 +243,7 @@ class Trainer():
# get alternative coordinates for ground-truth
true_alt = torch.zeros_like(true)
true_alt.scatter_(2, self.l2a[seq,:,None].repeat(1,1,1,3), true)
print(true_alt)
natRs_all, _n0 = self.compute_allatom_coords(seq, true[...,:3,:], true_tors)
natRs_all_alt, _n1 = self.compute_allatom_coords(seq, true_alt[...,:3,:], true_tors_alt)
predTs = pred[-1,...]
@@ -348,10 +349,10 @@ class Trainer():
chain2 = torch.zeros_like(same_chain, dtype=bool)
chain2[:,L0:,L0:] = True
_, allatom_lddt_c2 = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True)
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True, bin_scaling=0.5)
loss_s.append(allatom_lddt_c2.detach())
_, allatom_lddt_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True, bin_scaling=0.5)
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True)
loss_s.append(allatom_lddt_inter.detach())
# hbond [use all atoms not just those in native]
#hb_loss = calc_hb(
@@ -733,7 +734,7 @@ class Trainer():
# load model
loaded_epoch, best_valid_loss = self.load_model(ddp_model, optimizer, scheduler, scaler,
self.model_name, gpu, resume_train=True)
self.model_name, gpu, suffix="best", resume_train=True)
if (self.eval):
# run protein/NA prediction (TEMPLATED)
@@ -742,13 +743,13 @@ class Trainer():
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
# run protein/NA prediction (NON-TEMPLATED)
_, _, _ = self.valid_ppi_cycle(
ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
#_, _, _ = self.valid_ppi_cycle(
# ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
# run RNA prediction
#_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, 0, verbose=True)
_, _, _ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, 0, verbose=True)
dist.destroy_process_group()
return
@@ -816,7 +817,7 @@ class Trainer():
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'best_loss': best_valid_loss},
self.checkpoint_fn(self.model_name, 'last'))
self.checkpoint_fn(self.model_name, str(epoch)))
dist.destroy_process_group()