diff --git a/models/rf3/src/rf3/loss/af3_losses.py b/models/rf3/src/rf3/loss/af3_losses.py index 1bb0cd2..974db26 100644 --- a/models/rf3/src/rf3/loss/af3_losses.py +++ b/models/rf3/src/rf3/loss/af3_losses.py @@ -349,13 +349,9 @@ class SubunitSymmetryResolution(nn.Module): x_native = symm_input["coord_atom_lvl"].to(x_pred.device) mask_native = symm_input["mask_atom_lvl"].to(x_pred.device) - try: - x_native_aln, x_native_mask = self._resolve_subunits( - mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred - ) - except Exception: - # fd ... TO DO: DEBUG! - return loss_input + x_native_aln, x_native_mask = self._resolve_subunits( + mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred + ) loss_input["X_gt_L"] = x_native_aln loss_input["crd_mask_L"] = x_native_mask