mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fixed interface lddt loss
This commit is contained in:
@@ -778,7 +778,9 @@ def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain,
|
||||
if negative:
|
||||
# ignore atoms between different chains
|
||||
pair_mask *= same_chain.bool()[:,:,:,None,None]
|
||||
|
||||
elif interface:
|
||||
# ignore atoms between the same chain
|
||||
pair_mask *= ~same_chain.bool()[:,:,:,None,None]
|
||||
delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14)
|
||||
|
||||
lddt = torch.zeros( (N,L,Natm), device=P.device ) # (N, L, 27)
|
||||
|
||||
@@ -249,7 +249,16 @@ class LossTestCase(unittest.TestCase):
|
||||
B, L = true_crds.shape[:2]
|
||||
self.assertEqual(res_mask.shape[0], B)
|
||||
self.assertEqual(res_mask.shape[1], L)
|
||||
print(is_atom(msa[:,0,0]))
|
||||
print(res_mask)
|
||||
print(true_crds[res_mask][:,:23])
|
||||
break
|
||||
|
||||
|
||||
def test_sm_coords(self):
|
||||
""""""
|
||||
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
|
||||
print(true_crds.shape)
|
||||
print(true_crds[..., 1, :])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -348,10 +348,10 @@ class Trainer():
|
||||
chain2 = torch.zeros_like(same_chain, dtype=bool)
|
||||
chain2[:,L0:,L0:] = True
|
||||
_, allatom_lddt_c2 = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True)
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True, bin_scaling=0.5)
|
||||
loss_s.append(allatom_lddt_c2.detach())
|
||||
_, allatom_lddt_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True, bin_scaling=0.5)
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True)
|
||||
loss_s.append(allatom_lddt_inter.detach())
|
||||
# hbond [use all atoms not just those in native]
|
||||
#hb_loss = calc_hb(
|
||||
|
||||
Reference in New Issue
Block a user