added blackhole for missing coords

This commit is contained in:
Rohith Krishna
2022-07-15 08:32:39 -07:00
parent d45610cb93
commit 3307b89e17

View File

@@ -1466,9 +1466,9 @@ def loader_sm_compl(item, sm_chains, params, pick_top=True):
bond_feats = bond_feats[sel][:, sel]
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES)
# replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation
# init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1)
# xyz = torch.where(mask[...,None], xyz, init).contiguous()
# xyz = torch.nan_to_num(xyz)
init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1)
xyz = torch.where(mask[...,None], xyz, init).contiguous()
xyz = torch.nan_to_num(xyz)
return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\
xyz.float(), mask, idx.long(), \