mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Fix nucleic acid (NA) complex loading
This commit is contained in:
@@ -1273,7 +1273,7 @@ def loader_na_complex(item, Ls, params, native_NA_frac=0.25, negative=False, pic
|
||||
chain_idx[:Ls[0], :Ls[0]] = 1
|
||||
chain_idx[Ls[0]:, Ls[0]:] = 1 # fd - "negatives" still predict DNA double helix
|
||||
bond_feats = torch.zeros((sum(Ls), sum(Ls))).long()
|
||||
bond_feats[:Ls[0], :Ls[0]] = get_protein_bond_feats(L_s[0])
|
||||
bond_feats[:Ls[0], :Ls[0]] = get_protein_bond_feats(Ls[0])
|
||||
bond_feats[Ls[0]:, Ls[0]:] = get_protein_bond_feats(sum(Ls[1:]))
|
||||
|
||||
init = torch.cat((
|
||||
@@ -1470,7 +1470,7 @@ def loader_sm_compl(item, sm_chains, params, pick_top=True):
|
||||
chain_idx = chain_idx[sel][:,sel]
|
||||
bond_feats = bond_feats[sel][:, sel]
|
||||
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES)
|
||||
# replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation
|
||||
# replace missing with blackholes & convert NaN to zeros to avoid any NaN problems during loss calculation
|
||||
init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(xyz.shape[0], xyz.shape[1], 1, 1)
|
||||
xyz = torch.where(mask[...,None], xyz, init).contiguous()
|
||||
xyz = torch.nan_to_num(xyz)
|
||||
|
||||
@@ -173,6 +173,24 @@ class LossTestCase(unittest.TestCase):
|
||||
for i in range(1,5):
|
||||
self.assertLess(fapes[i-1], fapes[i])
|
||||
break
|
||||
with self.subTest("test that FAPE loss over only the atoms can be calculated"):
|
||||
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
|
||||
label_aa_s = msa[:, 0]
|
||||
seq = label_aa_s[:,0].clone()
|
||||
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
|
||||
frames, frame_mask = get_frames(
|
||||
true_crds, atom_mask, seq, frame_indices, atom_frames
|
||||
)
|
||||
|
||||
rotation_mask = is_atom(seq)
|
||||
atom_fape = compute_general_FAPE(
|
||||
true_crds[:,rotation_mask[0],:,:3],
|
||||
true_crds[:,rotation_mask[0],:,:3],
|
||||
atom_mask[:,rotation_mask[0]],
|
||||
frames[:,rotation_mask[0]],
|
||||
frame_mask[:,rotation_mask[0]]
|
||||
)
|
||||
self.assertAlmostEqual(int(atom_fape.numpy()),0)
|
||||
|
||||
def test_get_frames(self):
|
||||
"""test that nodes in atom frames are relatively close to each other (because they should be bonded)"""
|
||||
|
||||
@@ -317,6 +317,15 @@ class Trainer():
|
||||
loss_s.append(l_fape.detach())
|
||||
tot_loss += w_str*l_fape.mean()
|
||||
|
||||
rotation_mask = is_atom(seq)
|
||||
atom_fape = compute_general_FAPE(
|
||||
pred_allatom[:,rotation_mask[0],:,:3],
|
||||
nat_symm[None,rotation_mask[0],:,:3],
|
||||
xs_mask[:,rotation_mask[0]],
|
||||
frames[:,rotation_mask[0]],
|
||||
frame_mask[:,rotation_mask[0]]
|
||||
)
|
||||
loss_s.append(atom_fape.detach())
|
||||
# cart bonded (bond geometry)
|
||||
bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
|
||||
if w_bond > 0.0:
|
||||
@@ -520,7 +529,7 @@ class Trainer():
|
||||
self.n_valid_rna = len(valid_rna.keys())
|
||||
self.n_valid_sm_compl = len(valid_sm_compl.keys())
|
||||
|
||||
#self.n_valid_pdb = 4
|
||||
self.n_valid_pdb = 200
|
||||
#self.n_valid_homo = 4
|
||||
#self.n_valid_compl = 4
|
||||
#self.n_valid_neg = 4
|
||||
@@ -579,21 +588,21 @@ class Trainer():
|
||||
loader_pdb, valid_pdb,
|
||||
self.loader_param, homo, p_homo_cut=-1.0
|
||||
)
|
||||
valid_homo_set = Dataset(
|
||||
list(valid_homo.keys())[:self.n_valid_homo],
|
||||
loader_pdb, valid_homo,
|
||||
self.loader_param, homo, p_homo_cut=2.0
|
||||
)
|
||||
valid_compl_set = DatasetComplex(
|
||||
list(valid_compl.keys())[:self.n_valid_compl],
|
||||
loader_complex, valid_compl,
|
||||
self.loader_param, negative=False
|
||||
)
|
||||
valid_neg_set = DatasetComplex(
|
||||
list(valid_neg.keys())[:self.n_valid_neg],
|
||||
loader_complex, valid_neg,
|
||||
self.loader_param, negative=True
|
||||
)
|
||||
# valid_homo_set = Dataset(
|
||||
# list(valid_homo.keys())[:self.n_valid_homo],
|
||||
# loader_pdb, valid_homo,
|
||||
# self.loader_param, homo, p_homo_cut=2.0
|
||||
# )
|
||||
# valid_compl_set = DatasetComplex(
|
||||
# list(valid_compl.keys())[:self.n_valid_compl],
|
||||
# loader_complex, valid_compl,
|
||||
# self.loader_param, negative=False
|
||||
# )
|
||||
# valid_neg_set = DatasetComplex(
|
||||
# list(valid_neg.keys())[:self.n_valid_neg],
|
||||
# loader_complex, valid_neg,
|
||||
# self.loader_param, negative=True
|
||||
# )
|
||||
# valid_na_compl_set = DatasetNAComplex(
|
||||
# list(valid_na_compl.keys())[:self.n_valid_na_compl],
|
||||
# loader_na_complex, valid_na_compl,
|
||||
@@ -647,9 +656,9 @@ class Trainer():
|
||||
)
|
||||
|
||||
valid_pdb_sampler = data.distributed.DistributedSampler(valid_pdb_set, num_replicas=world_size, rank=rank)
|
||||
valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
|
||||
valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
|
||||
valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
|
||||
# valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
|
||||
# valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
|
||||
# valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
|
||||
@@ -659,9 +668,9 @@ class Trainer():
|
||||
|
||||
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM)
|
||||
valid_pdb_loader = data.DataLoader(valid_pdb_set, sampler=valid_pdb_sampler, **LOAD_PARAM)
|
||||
valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
|
||||
valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
|
||||
valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
|
||||
# valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
|
||||
# valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
|
||||
# valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
|
||||
# valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
|
||||
# valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
|
||||
# valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)
|
||||
|
||||
Reference in New Issue
Block a user