mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
skip ss computation when not ss example
This commit is contained in:
@@ -179,39 +179,44 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
if not self.is_inference:
|
||||
|
||||
# First, annotate as usual
|
||||
atom_array = annotate_na_ss(atom_array,
|
||||
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
|
||||
|
||||
if is_nucleic_ss_example:
|
||||
|
||||
atom_array = annotate_na_ss(atom_array,
|
||||
NA_only=self.NA_only,
|
||||
planar_only=self.planar_only,
|
||||
p_canonical_bp_filter=self.p_canonical_bp_filter,
|
||||
)
|
||||
|
||||
# Generate symmetric partner annotations at the token level for masking purposes.
|
||||
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
|
||||
partner_sym_map = {
|
||||
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
|
||||
for i, ts_i in enumerate(token_starts)
|
||||
}
|
||||
# Generate symmetric partner annotations at the token level for masking purposes.
|
||||
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
|
||||
partner_sym_map = {
|
||||
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
|
||||
for i, ts_i in enumerate(token_starts)
|
||||
}
|
||||
|
||||
# # Sample mask on token level:
|
||||
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
|
||||
token_mask_to_show = self._sample_where_to_show_ss(
|
||||
n_tokens,
|
||||
is_nucleic_ss_example=is_nucleic_ss_example,
|
||||
give_partial_feats=give_partial_feats,
|
||||
partner_sym_map=partner_sym_map,
|
||||
) # Mask vec for tokens where ss shown
|
||||
# # Sample mask on token level:
|
||||
token_mask_to_show = self._sample_where_to_show_ss(
|
||||
n_tokens,
|
||||
is_nucleic_ss_example=is_nucleic_ss_example,
|
||||
give_partial_feats=give_partial_feats,
|
||||
partner_sym_map=partner_sym_map,
|
||||
) # Mask vec for tokens where ss shown
|
||||
|
||||
# Spread mask to atom level
|
||||
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
|
||||
|
||||
# Extract the base pair annotations
|
||||
bp_partners_atom = atom_array.get_annotation("bp_partners")
|
||||
# Spread mask to atom level
|
||||
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
|
||||
|
||||
# Extract the base pair annotations
|
||||
bp_partners_atom = atom_array.get_annotation("bp_partners")
|
||||
|
||||
# Remove unshown positions from bp_partners annotation
|
||||
bp_partners_atom[~is_ss_shown] = None
|
||||
|
||||
# Reset the annotation with newly hidden positions
|
||||
atom_array.set_annotation("bp_partners", bp_partners_atom)
|
||||
# Remove unshown positions from bp_partners annotation
|
||||
bp_partners_atom[~is_ss_shown] = None
|
||||
|
||||
# Reset the annotation with newly hidden positions
|
||||
atom_array.set_annotation("bp_partners", bp_partners_atom)
|
||||
else:
|
||||
atom_array.set_annotation("bp_partners", np.array([None]*len(atom_array)))
|
||||
|
||||
# Inference case: create from commandline args
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user