fixed interface lddt loss

This commit is contained in:
Rohith Krishna
2022-08-02 12:48:40 -07:00
parent bac7bc3d8b
commit a7f7f41e4f
3 changed files with 15 additions and 4 deletions

View File

@@ -778,7 +778,9 @@ def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain,
if negative:
# ignore atoms between different chains
pair_mask *= same_chain.bool()[:,:,:,None,None]
elif interface:
# ignore atoms between the same chain
pair_mask *= ~same_chain.bool()[:,:,:,None,None]
delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14)
lddt = torch.zeros( (N,L,Natm), device=P.device ) # (N, L, 27)

View File

@@ -249,7 +249,16 @@ class LossTestCase(unittest.TestCase):
B, L = true_crds.shape[:2]
self.assertEqual(res_mask.shape[0], B)
self.assertEqual(res_mask.shape[1], L)
print(is_atom(msa[:,0,0]))
print(res_mask)
print(true_crds[res_mask][:,:23])
break
def test_sm_coords(self):
""""""
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
print(true_crds.shape)
print(true_crds[..., 1, :])
if __name__ == '__main__':
unittest.main()

View File

@@ -348,10 +348,10 @@ class Trainer():
chain2 = torch.zeros_like(same_chain, dtype=bool)
chain2[:,L0:,L0:] = True
_, allatom_lddt_c2 = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True)
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True, bin_scaling=0.5)
loss_s.append(allatom_lddt_c2.detach())
_, allatom_lddt_inter = calc_allatom_lddt_loss(
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True, bin_scaling=0.5)
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True)
loss_s.append(allatom_lddt_inter.detach())
# hbond [use all atoms not just those in native]
#hb_loss = calc_hb(