From 3307b89e17d9fc035b6998e70fdc342a04b92c5e Mon Sep 17 00:00:00 2001 From: Rohith Krishna Date: Fri, 15 Jul 2022 08:32:39 -0700 Subject: [PATCH] added blackhole for missing coords --- RF2_allatom/data_loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RF2_allatom/data_loader.py b/RF2_allatom/data_loader.py index a7558fd..2d4bbbc 100644 --- a/RF2_allatom/data_loader.py +++ b/RF2_allatom/data_loader.py @@ -1466,9 +1466,9 @@ def loader_sm_compl(item, sm_chains, params, pick_top=True): bond_feats = bond_feats[sel][:, sel] bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation - # init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1) - # xyz = torch.where(mask[...,None], xyz, init).contiguous() - # xyz = torch.nan_to_num(xyz) + init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1) + xyz = torch.where(mask[...,None], xyz, init).contiguous() + xyz = torch.nan_to_num(xyz) return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\ xyz.float(), mask, idx.long(), \