skip ss computation when not ss example

This commit is contained in:
Raktim Mitra
2026-02-17 11:06:26 -08:00
parent 3dfa04ac20
commit 91a0eb22ec

View File

@@ -179,39 +179,44 @@ class CalculateNucleicAcidGeomFeats(Transform):
if not self.is_inference:
# First, annotate as usual
atom_array = annotate_na_ss(atom_array,
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
if is_nucleic_ss_example:
atom_array = annotate_na_ss(atom_array,
NA_only=self.NA_only,
planar_only=self.planar_only,
p_canonical_bp_filter=self.p_canonical_bp_filter,
)
# Generate symmetric partner annotations at the token level for masking purposes.
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
partner_sym_map = {
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
for i, ts_i in enumerate(token_starts)
}
# Generate symmetric partner annotations at the token level for masking purposes.
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
partner_sym_map = {
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
for i, ts_i in enumerate(token_starts)
}
# # Sample mask on token level:
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
token_mask_to_show = self._sample_where_to_show_ss(
n_tokens,
is_nucleic_ss_example=is_nucleic_ss_example,
give_partial_feats=give_partial_feats,
partner_sym_map=partner_sym_map,
) # Mask vec for tokens where ss shown
# # Sample mask on token level:
token_mask_to_show = self._sample_where_to_show_ss(
n_tokens,
is_nucleic_ss_example=is_nucleic_ss_example,
give_partial_feats=give_partial_feats,
partner_sym_map=partner_sym_map,
) # Mask vec for tokens where ss shown
# Spread mask to atom level
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
# Extract the base pair annotations
bp_partners_atom = atom_array.get_annotation("bp_partners")
# Spread mask to atom level
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
# Extract the base pair annotations
bp_partners_atom = atom_array.get_annotation("bp_partners")
# Remove unshown positions from bp_partners annotation
bp_partners_atom[~is_ss_shown] = None
# Reset the annotation with newly hidden positions
atom_array.set_annotation("bp_partners", bp_partners_atom)
# Remove unshown positions from bp_partners annotation
bp_partners_atom[~is_ss_shown] = None
# Reset the annotation with newly hidden positions
atom_array.set_annotation("bp_partners", bp_partners_atom)
else:
atom_array.set_annotation("bp_partners", np.array([None]*len(atom_array)))
# Inference case: create from commandline args
else: