fix na complex loading

This commit is contained in:
Rohith Krishna
2022-07-19 11:34:55 -07:00
parent 94f6d81176
commit ea737a454d
3 changed files with 51 additions and 24 deletions

View File

@@ -1273,7 +1273,7 @@ def loader_na_complex(item, Ls, params, native_NA_frac=0.25, negative=False, pic
chain_idx[:Ls[0], :Ls[0]] = 1
chain_idx[Ls[0]:, Ls[0]:] = 1 # fd - "negatives" still predict DNA double helix
bond_feats = torch.zeros((sum(Ls), sum(Ls))).long()
bond_feats[:Ls[0], :Ls[0]] = get_protein_bond_feats(L_s[0])
bond_feats[:Ls[0], :Ls[0]] = get_protein_bond_feats(Ls[0])
bond_feats[Ls[0]:, Ls[0]:] = get_protein_bond_feats(sum(Ls[1:]))
init = torch.cat((
@@ -1470,7 +1470,7 @@ def loader_sm_compl(item, sm_chains, params, pick_top=True):
chain_idx = chain_idx[sel][:,sel]
bond_feats = bond_feats[sel][:, sel]
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES)
# replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation
# replace missing with blackholes & convert NaN to zeros to avoid any NaN problems during loss calculation
init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(xyz.shape[0], xyz.shape[1], 1, 1)
xyz = torch.where(mask[...,None], xyz, init).contiguous()
xyz = torch.nan_to_num(xyz)

View File

@@ -173,6 +173,24 @@ class LossTestCase(unittest.TestCase):
for i in range(1,5):
self.assertLess(fapes[i-1], fapes[i])
break
with self.subTest("test that FAPE loss over only the atoms can be calculated"):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
label_aa_s = msa[:, 0]
seq = label_aa_s[:,0].clone()
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
frames, frame_mask = get_frames(
true_crds, atom_mask, seq, frame_indices, atom_frames
)
rotation_mask = is_atom(seq)
atom_fape = compute_general_FAPE(
true_crds[:,rotation_mask[0],:,:3],
true_crds[:,rotation_mask[0],:,:3],
atom_mask[:,rotation_mask[0]],
frames[:,rotation_mask[0]],
frame_mask[:,rotation_mask[0]]
)
self.assertAlmostEqual(int(atom_fape.numpy()),0)
def test_get_frames(self):
"""test that nodes in atom frames are relatively close to each other (because they should be bonded)"""

View File

@@ -317,6 +317,15 @@ class Trainer():
loss_s.append(l_fape.detach())
tot_loss += w_str*l_fape.mean()
rotation_mask = is_atom(seq)
atom_fape = compute_general_FAPE(
pred_allatom[:,rotation_mask[0],:,:3],
nat_symm[None,rotation_mask[0],:,:3],
xs_mask[:,rotation_mask[0]],
frames[:,rotation_mask[0]],
frame_mask[:,rotation_mask[0]]
)
loss_s.append(atom_fape.detach())
# cart bonded (bond geometry)
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
if w_bond > 0.0:
@@ -520,7 +529,7 @@ class Trainer():
self.n_valid_rna = len(valid_rna.keys())
self.n_valid_sm_compl = len(valid_sm_compl.keys())
#self.n_valid_pdb = 4
self.n_valid_pdb = 200
#self.n_valid_homo = 4
#self.n_valid_compl = 4
#self.n_valid_neg = 4
@@ -579,21 +588,21 @@ class Trainer():
loader_pdb, valid_pdb,
self.loader_param, homo, p_homo_cut=-1.0
)
valid_homo_set = Dataset(
list(valid_homo.keys())[:self.n_valid_homo],
loader_pdb, valid_homo,
self.loader_param, homo, p_homo_cut=2.0
)
valid_compl_set = DatasetComplex(
list(valid_compl.keys())[:self.n_valid_compl],
loader_complex, valid_compl,
self.loader_param, negative=False
)
valid_neg_set = DatasetComplex(
list(valid_neg.keys())[:self.n_valid_neg],
loader_complex, valid_neg,
self.loader_param, negative=True
)
# valid_homo_set = Dataset(
# list(valid_homo.keys())[:self.n_valid_homo],
# loader_pdb, valid_homo,
# self.loader_param, homo, p_homo_cut=2.0
# )
# valid_compl_set = DatasetComplex(
# list(valid_compl.keys())[:self.n_valid_compl],
# loader_complex, valid_compl,
# self.loader_param, negative=False
# )
# valid_neg_set = DatasetComplex(
# list(valid_neg.keys())[:self.n_valid_neg],
# loader_complex, valid_neg,
# self.loader_param, negative=True
# )
# valid_na_compl_set = DatasetNAComplex(
# list(valid_na_compl.keys())[:self.n_valid_na_compl],
# loader_na_complex, valid_na_compl,
@@ -647,9 +656,9 @@ class Trainer():
)
valid_pdb_sampler = data.distributed.DistributedSampler(valid_pdb_set, num_replicas=world_size, rank=rank)
valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
# valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
# valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
# valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
# valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
# valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
# valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
@@ -659,9 +668,9 @@ class Trainer():
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM)
valid_pdb_loader = data.DataLoader(valid_pdb_set, sampler=valid_pdb_sampler, **LOAD_PARAM)
valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
# valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
# valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
# valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
# valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
# valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
# valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)