mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
added blackhole for missing coords
This commit is contained in:
@@ -1466,9 +1466,9 @@ def loader_sm_compl(item, sm_chains, params, pick_top=True):
|
||||
bond_feats = bond_feats[sel][:, sel]
|
||||
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES)
|
||||
# replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation
|
||||
# init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1)
|
||||
# xyz = torch.where(mask[...,None], xyz, init).contiguous()
|
||||
# xyz = torch.nan_to_num(xyz)
|
||||
init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1)
|
||||
xyz = torch.where(mask[...,None], xyz, init).contiguous()
|
||||
xyz = torch.nan_to_num(xyz)
|
||||
|
||||
return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\
|
||||
xyz.float(), mask, idx.long(), \
|
||||
|
||||
Reference in New Issue
Block a user