diff --git a/models/rfd3/src/rfd3/transforms/na_geom.py b/models/rfd3/src/rfd3/transforms/na_geom.py index f1f400a..ba45c7f 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom.py +++ b/models/rfd3/src/rfd3/transforms/na_geom.py @@ -179,39 +179,44 @@ class CalculateNucleicAcidGeomFeats(Transform): if not self.is_inference: # First, annotate as usual - atom_array = annotate_na_ss(atom_array, + is_nucleic_ss_example, give_partial_feats = self._sample_training_flags() + + if is_nucleic_ss_example: + + atom_array = annotate_na_ss(atom_array, NA_only=self.NA_only, planar_only=self.planar_only, p_canonical_bp_filter=self.p_canonical_bp_filter, ) - # Generate symmetric partner annotations at the token level for masking purposes. - # choice for object-consistency: if already masked/undefined: be a list mapping to self-index. - partner_sym_map = { - i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i] - for i, ts_i in enumerate(token_starts) - } + # Generate symmetric partner annotations at the token level for masking purposes. + # choice for object-consistency: if already masked/undefined: be a list mapping to self-index. + partner_sym_map = { + i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i] + for i, ts_i in enumerate(token_starts) + } - # # Sample mask on token level: - is_nucleic_ss_example, give_partial_feats = self._sample_training_flags() - token_mask_to_show = self._sample_where_to_show_ss( - n_tokens, - is_nucleic_ss_example=is_nucleic_ss_example, - give_partial_feats=give_partial_feats, - partner_sym_map=partner_sym_map, - ) # Mask vec for tokens where ss shown + # # Sample mask on token level: + token_mask_to_show = self._sample_where_to_show_ss( + n_tokens, + is_nucleic_ss_example=is_nucleic_ss_example, + give_partial_feats=give_partial_feats, + partner_sym_map=partner_sym_map, + ) # Mask vec for tokens where ss shown - # Spread mask to atom level - is_ss_shown = spread_token_wise(atom_array, token_mask_to_show) - - # Extract the base pair annotations - bp_partners_atom = atom_array.get_annotation("bp_partners") + # Spread mask to atom level + is_ss_shown = spread_token_wise(atom_array, token_mask_to_show) + + # Extract the base pair annotations + bp_partners_atom = atom_array.get_annotation("bp_partners") - # Remove unshown positions from bp_partners annotation - bp_partners_atom[~is_ss_shown] = None - - # Reset the annotation with newly hidden positions - atom_array.set_annotation("bp_partners", bp_partners_atom) + # Remove unshown positions from bp_partners annotation + bp_partners_atom[~is_ss_shown] = None + + # Reset the annotation with newly hidden positions + atom_array.set_annotation("bp_partners", bp_partners_atom) + else: + atom_array.set_annotation("bp_partners", np.array([None]*len(atom_array))) # Inference case: create from commandline args else: