mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Add nucleic SS metrics + conditioning + pseudoknot dataset config
This commit is contained in:
@@ -62,6 +62,9 @@ global_transform_args:
|
||||
min_ss_island_len: 1
|
||||
max_ss_island_len: 10
|
||||
|
||||
# Nucleic acid features
|
||||
add_na_pair_features: false
|
||||
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 5.0
|
||||
|
||||
@@ -43,6 +43,9 @@ dataset:
|
||||
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
|
||||
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
|
||||
|
||||
# Nucleic acid features
|
||||
add_na_pair_features: ${datasets.global_transform_args.add_na_pair_features}
|
||||
|
||||
# Cropping
|
||||
crop_size: ${datasets.crop_size}
|
||||
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
|
||||
@@ -56,4 +59,5 @@ dataset:
|
||||
|
||||
# Other dataset-specific parameters
|
||||
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
|
||||
token_1d_features: ${model.net.token_initializer.token_1d_features}
|
||||
token_1d_features: ${model.net.token_initializer.token_1d_features}
|
||||
token_2d_features: ${model.net.token_initializer.token_2d_features}
|
||||
|
||||
@@ -37,4 +37,5 @@ dataset:
|
||||
|
||||
# Other dataset-specific parameters
|
||||
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
|
||||
token_1d_features: ${model.net.token_initializer.token_1d_features}
|
||||
token_1d_features: ${model.net.token_initializer.token_1d_features}
|
||||
token_2d_features: ${model.net.token_initializer.token_2d_features}
|
||||
9
models/rfd3/configs/datasets/val/pseudoknot.yaml
Normal file
9
models/rfd3/configs/datasets/val/pseudoknot.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
defaults:
|
||||
- unconditional
|
||||
- _self_
|
||||
|
||||
dataset:
|
||||
name: pseudoknot
|
||||
eval_every_n: 1
|
||||
data: /home/afavor/git/RFD3/modelhub/projects/aa_design/tests/test_data/pseudoknot.json
|
||||
87
models/rfd3/configs/experiment/rfd3na-ss.yaml
Normal file
87
models/rfd3/configs/experiment/rfd3na-ss.yaml
Normal file
@@ -0,0 +1,87 @@
|
||||
# @package _global_
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: null
|
||||
- _self_
|
||||
|
||||
name: rfd3na-SScond
|
||||
tags: [print-model]
|
||||
ckpt_path: null
|
||||
|
||||
model:
|
||||
net:
|
||||
token_initializer:
|
||||
token_1d_features:
|
||||
ref_motif_token_type: 3
|
||||
restype: 32
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
token_2d_features:
|
||||
bp_partners: 3 # Unspecified, pair, loop
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
ref_charge: 1
|
||||
ref_mask: 1
|
||||
ref_is_motif_atom_with_fixed_coord: 1
|
||||
ref_is_motif_atom_unindexed: 1
|
||||
has_zero_occupancy: 1
|
||||
ref_pos: 3
|
||||
|
||||
# Guided features
|
||||
ref_atomwise_rasa: 3
|
||||
active_donor: 1
|
||||
active_acceptor: 1
|
||||
is_atom_level_hotspot: 1
|
||||
diffusion_module:
|
||||
n_recycle: 2
|
||||
use_local_token_attention: True
|
||||
diffusion_transformer:
|
||||
n_local_tokens: 32
|
||||
n_keys: 128
|
||||
|
||||
inference_sampler:
|
||||
num_timesteps: 100
|
||||
|
||||
|
||||
datasets:
|
||||
diffusion_batch_size_train: 16
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
association_scheme: atom23
|
||||
add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 2.0
|
||||
sequence_design:
|
||||
frequency: 0.5
|
||||
tipatom:
|
||||
frequency: 5.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
|
||||
pdb:
|
||||
probability: 1.0
|
||||
# rna_monomer_distillation:
|
||||
# probability: 1.0
|
||||
|
||||
# val:
|
||||
# pseudoknot:
|
||||
# dataset:
|
||||
# # eval_every_n: 10
|
||||
# eval_every_n: 2
|
||||
|
||||
trainer:
|
||||
devices_per_node: 1
|
||||
limit_train_batches: 10
|
||||
limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: false
|
||||
@@ -25,6 +25,9 @@ token_initializer: # formerly known as the trunk
|
||||
ref_plddt: 1
|
||||
is_non_loopy: 1
|
||||
|
||||
# Optional 2D token feature definitions (empty by default)
|
||||
token_2d_features: {}
|
||||
|
||||
downcast: ${model.net.diffusion_module.downcast}
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
|
||||
@@ -20,3 +20,12 @@ hbond_metrics:
|
||||
_target_: rfd3.metrics.hbonds_hbplus_metrics.HbondMetrics
|
||||
cutoff_HA_dist: 3
|
||||
cutoff_DA_distance: 3.5
|
||||
|
||||
nucleic_ss_similarity:
|
||||
_target_: rfd3.metrics.nucleic_ss_metrics.NucleicSSSimilarityMetrics
|
||||
restrict_to_nucleic: True
|
||||
compute_for_diffused_region_only: False
|
||||
annotate_predicted_fresh: True
|
||||
annotation_NA_only: False
|
||||
annotation_planar_only: True
|
||||
|
||||
|
||||
314
models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
Normal file
314
models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import logging
|
||||
|
||||
import bdb
|
||||
import numpy as np
|
||||
from biotite.structure import AtomArray
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
)
|
||||
|
||||
from rfd3.transforms.na_geom_utils import annotate_na_ss
|
||||
|
||||
from foundry.metrics.metric import Metric
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
|
||||
def _safe_f1_from_sizes(intersection_n: int, pred_n: int, gt_n: int) -> float:
|
||||
"""Return F1 with sensible empty-set handling."""
|
||||
if pred_n == 0 and gt_n == 0:
|
||||
return 1.0
|
||||
|
||||
precision = float(intersection_n / pred_n) if pred_n > 0 else 0.0
|
||||
recall = float(intersection_n / gt_n) if gt_n > 0 else 0.0
|
||||
|
||||
if precision + recall == 0.0:
|
||||
return 0.0
|
||||
|
||||
return float(2.0 * precision * recall / (precision + recall))
|
||||
|
||||
|
||||
def _get_token_ids(atom_array: AtomArray) -> np.ndarray:
|
||||
token_starts = get_token_starts(atom_array)
|
||||
token_level_array = atom_array[token_starts]
|
||||
return np.asarray(token_level_array.token_id, dtype=int)
|
||||
|
||||
|
||||
def _get_candidate_token_ids(
|
||||
atom_array: AtomArray,
|
||||
*,
|
||||
restrict_to_nucleic: bool,
|
||||
compute_for_diffused_region_only: bool,
|
||||
) -> set[int]:
|
||||
"""Return a set of token_ids to include for scoring."""
|
||||
token_starts = get_token_starts(atom_array)
|
||||
token_level_array = atom_array[token_starts]
|
||||
token_ids = np.asarray(token_level_array.token_id, dtype=int)
|
||||
|
||||
token_mask = np.ones(len(token_ids), dtype=bool)
|
||||
|
||||
if restrict_to_nucleic:
|
||||
is_rna = (
|
||||
np.asarray(getattr(token_level_array, "is_rna"), dtype=bool)
|
||||
if hasattr(token_level_array, "is_rna")
|
||||
else np.zeros(len(token_ids), dtype=bool)
|
||||
)
|
||||
is_dna = (
|
||||
np.asarray(getattr(token_level_array, "is_dna"), dtype=bool)
|
||||
if hasattr(token_level_array, "is_dna")
|
||||
else np.zeros(len(token_ids), dtype=bool)
|
||||
)
|
||||
token_mask &= (is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
|
||||
|
||||
if compute_for_diffused_region_only:
|
||||
if hasattr(token_level_array, "is_motif_atom"):
|
||||
token_mask &= ~np.asarray(token_level_array.is_motif_atom, dtype=bool)
|
||||
elif hasattr(token_level_array, "is_motif_token"):
|
||||
token_mask &= ~np.asarray(token_level_array.is_motif_token, dtype=bool)
|
||||
|
||||
return set(int(t) for t in token_ids[token_mask].tolist())
|
||||
|
||||
|
||||
def _extract_bp_pairs(
|
||||
atom_array: AtomArray,
|
||||
*,
|
||||
allowed_token_ids: set[int],
|
||||
) -> set[tuple[int, int]]:
|
||||
"""Extract unordered base-pair edges from bp_partner annotations.
|
||||
|
||||
Pairs are represented as (min_token_id, max_token_id).
|
||||
"""
|
||||
if "bp_partner" not in atom_array.get_annotation_categories():
|
||||
raise ValueError("atom_array missing bp_partner annotation")
|
||||
|
||||
token_starts = get_token_starts(atom_array)
|
||||
token_level_array = atom_array[token_starts]
|
||||
token_ids = np.asarray(token_level_array.token_id, dtype=int)
|
||||
token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}
|
||||
|
||||
bp_partner_ann = atom_array.bp_partner
|
||||
pairs: set[tuple[int, int]] = set()
|
||||
|
||||
for pos, start_idx in enumerate(token_starts.tolist()):
|
||||
i_tid = int(token_ids[pos])
|
||||
if i_tid not in allowed_token_ids:
|
||||
continue
|
||||
|
||||
partners = bp_partner_ann[int(start_idx)]
|
||||
if partners is None:
|
||||
continue
|
||||
if not isinstance(partners, (list, tuple, np.ndarray)):
|
||||
continue
|
||||
|
||||
for partner_token_id in partners:
|
||||
try:
|
||||
j_tid = int(partner_token_id)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if j_tid == i_tid or j_tid not in allowed_token_ids:
|
||||
continue
|
||||
|
||||
if j_tid not in token_id_to_pos:
|
||||
continue
|
||||
|
||||
a, b = (i_tid, j_tid) if i_tid < j_tid else (j_tid, i_tid)
|
||||
pairs.add((a, b))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def _extract_loop_and_paired_token_ids(
|
||||
atom_array: AtomArray,
|
||||
*,
|
||||
allowed_token_ids: set[int],
|
||||
) -> tuple[set[int], set[int]]:
|
||||
"""Return (loop_token_ids, paired_token_ids) within the allowed token set."""
|
||||
if "bp_partner" not in atom_array.get_annotation_categories():
|
||||
raise ValueError("atom_array missing bp_partner annotation")
|
||||
|
||||
token_starts = get_token_starts(atom_array)
|
||||
token_level_array = atom_array[token_starts]
|
||||
token_ids = np.asarray(token_level_array.token_id, dtype=int)
|
||||
token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}
|
||||
|
||||
bp_partner_ann = atom_array.bp_partner
|
||||
|
||||
loop_token_ids: set[int] = set()
|
||||
paired_token_ids: set[int] = set()
|
||||
|
||||
for pos, start_idx in enumerate(token_starts.tolist()):
|
||||
i_tid = int(token_ids[pos])
|
||||
if i_tid not in allowed_token_ids:
|
||||
continue
|
||||
|
||||
partners = bp_partner_ann[int(start_idx)]
|
||||
# New semantics:
|
||||
# - None => unannotated/masked (NOT a loop)
|
||||
# - [] => explicitly unpaired loop
|
||||
if partners is None:
|
||||
continue
|
||||
if not isinstance(partners, (list, tuple, np.ndarray)):
|
||||
continue
|
||||
if len(partners) == 0:
|
||||
loop_token_ids.add(i_tid)
|
||||
continue
|
||||
|
||||
for partner_token_id in partners:
|
||||
try:
|
||||
j_tid = int(partner_token_id)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if j_tid == i_tid or j_tid not in allowed_token_ids:
|
||||
continue
|
||||
if j_tid not in token_id_to_pos:
|
||||
continue
|
||||
paired_token_ids.add(i_tid)
|
||||
paired_token_ids.add(j_tid)
|
||||
|
||||
return loop_token_ids, paired_token_ids
|
||||
|
||||
|
||||
class NucleicSSSimilarityMetrics(Metric):
|
||||
"""Secondary-structure similarity for nucleic acids.
|
||||
|
||||
Reports:
|
||||
- `pair_f1`: F1 over the set of basepair edges implied by token-level `bp_partner`.
|
||||
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partner == []`).
|
||||
Unannotated tokens (`bp_partner is None`) are masked.
|
||||
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
|
||||
the prevalence of paired vs loop tokens in the GT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
restrict_to_nucleic: bool = True,
|
||||
compute_for_diffused_region_only: bool = False,
|
||||
annotate_predicted_fresh: bool = False,
|
||||
annotation_NA_only: bool = False,
|
||||
annotation_planar_only: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.restrict_to_nucleic = restrict_to_nucleic
|
||||
self.compute_for_diffused_region_only = compute_for_diffused_region_only
|
||||
self.annotate_predicted_fresh = annotate_predicted_fresh
|
||||
self.annotation_NA_only = annotation_NA_only
|
||||
self.annotation_planar_only = annotation_planar_only
|
||||
|
||||
@property
|
||||
def kwargs_to_compute_args(self):
|
||||
return {
|
||||
"ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
|
||||
"predicted_atom_array_stack": ("predicted_atom_array_stack",),
|
||||
}
|
||||
|
||||
def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
|
||||
if ground_truth_atom_array_stack is None or predicted_atom_array_stack is None:
|
||||
return {}
|
||||
|
||||
pair_f1_list: list[float] = []
|
||||
loop_f1_list: list[float] = []
|
||||
weighted_f1_list: list[float] = []
|
||||
|
||||
n_valid = 0
|
||||
|
||||
for gt_arr, pred_arr in zip(ground_truth_atom_array_stack, predicted_atom_array_stack):
|
||||
try:
|
||||
if "bp_partner" not in gt_arr.get_annotation_categories():
|
||||
continue
|
||||
|
||||
# Important: predicted AtomArrays are built from a template AtomArray.
|
||||
# If that template already carries bp_partner (often GT-derived), the
|
||||
# prediction can inherit it, yielding artificially perfect scores.
|
||||
# Optionally recompute bp_partner from the *predicted coordinates*.
|
||||
|
||||
if self.annotate_predicted_fresh:
|
||||
annotate_na_ss(
|
||||
pred_arr,
|
||||
NA_only=self.annotation_NA_only,
|
||||
planar_only=self.annotation_planar_only,
|
||||
overwrite=True,
|
||||
p_canonical_bp_filter=0.0,
|
||||
)
|
||||
|
||||
if "bp_partner" not in pred_arr.get_annotation_categories():
|
||||
continue
|
||||
|
||||
# Basic sanity check: token counts should match for aligned comparisons.
|
||||
gt_token_ids = _get_token_ids(gt_arr)
|
||||
pred_token_ids = _get_token_ids(pred_arr)
|
||||
if len(gt_token_ids) != len(pred_token_ids):
|
||||
continue
|
||||
|
||||
# Restrict to token_ids that are valid in both arrays.
|
||||
gt_allowed = _get_candidate_token_ids(
|
||||
gt_arr,
|
||||
restrict_to_nucleic=self.restrict_to_nucleic,
|
||||
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
|
||||
)
|
||||
pred_allowed = _get_candidate_token_ids(
|
||||
pred_arr,
|
||||
restrict_to_nucleic=self.restrict_to_nucleic,
|
||||
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
|
||||
)
|
||||
allowed = gt_allowed & pred_allowed
|
||||
|
||||
if len(allowed) == 0:
|
||||
continue
|
||||
|
||||
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
|
||||
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
|
||||
|
||||
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
|
||||
gt_arr, allowed_token_ids=allowed
|
||||
)
|
||||
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
|
||||
pred_arr, allowed_token_ids=allowed
|
||||
)
|
||||
|
||||
pair_tp = len(gt_pairs & pred_pairs)
|
||||
pair_pred_n = len(pred_pairs)
|
||||
pair_gt_n = len(gt_pairs)
|
||||
|
||||
loop_tp = len(gt_loop & pred_loop)
|
||||
loop_pred_n = len(pred_loop)
|
||||
loop_gt_n = len(gt_loop)
|
||||
|
||||
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
|
||||
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
|
||||
|
||||
pair_weight = len(gt_paired_tokens)
|
||||
loop_weight = len(gt_loop)
|
||||
total_weight = pair_weight + loop_weight
|
||||
if total_weight == 0:
|
||||
weighted_f1 = 1.0
|
||||
else:
|
||||
weighted_f1 = float(
|
||||
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
|
||||
)
|
||||
|
||||
pair_f1_list.append(pair_f1)
|
||||
loop_f1_list.append(loop_f1)
|
||||
weighted_f1_list.append(weighted_f1)
|
||||
n_valid += 1
|
||||
|
||||
except bdb.BdbQuit:
|
||||
# Allow interactive debuggers (pdb) to cleanly abort without being swallowed.
|
||||
raise
|
||||
except Exception as e:
|
||||
global_logger.error(f"Error computing nucleic-SS similarity: {e} | Skipping")
|
||||
continue
|
||||
|
||||
if n_valid == 0:
|
||||
return {}
|
||||
|
||||
return {
|
||||
"pair_f1": float(np.mean(pair_f1_list)),
|
||||
"loop_f1": float(np.mean(loop_f1_list)),
|
||||
"weighted_f1": float(np.mean(weighted_f1_list)),
|
||||
"n_valid_samples": int(n_valid),
|
||||
}
|
||||
@@ -143,6 +143,38 @@ class OneDFeatureEmbedder(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
class TwoDFeatureEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds 2D features into a single vector.
|
||||
|
||||
Args:
|
||||
features (dict): Dictionary of feature names and their number of channels.
|
||||
output_channels (int): Output dimension of the projected embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, features, output_channels):
|
||||
super().__init__()
|
||||
self.features = {k: v for k, v in features.items() if exists(v)}
|
||||
total_embedding_input_features = sum(self.features.values())
|
||||
self.embedders = nn.ModuleDict(
|
||||
{
|
||||
feature: EmbeddingLayer(
|
||||
n_channels, total_embedding_input_features, output_channels
|
||||
)
|
||||
for feature, n_channels in self.features.items()
|
||||
}
|
||||
)
|
||||
def collapse2D(self, x, L):
|
||||
return x.reshape((L, L, x.numel() // (L * L)))
|
||||
|
||||
def forward(self, f, collapse_length):
|
||||
return sum(
|
||||
tuple(
|
||||
self.embedders[feature](self.collapse2D(f[feature].float(), collapse_length))
|
||||
for feature, n_channels in self.features.items()
|
||||
if exists(n_channels)
|
||||
)
|
||||
)
|
||||
|
||||
class SinusoidalDistEmbed(nn.Module):
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ from rfd3.model.layers.blocks import (
|
||||
Downcast,
|
||||
LocalAtomTransformer,
|
||||
OneDFeatureEmbedder,
|
||||
TwoDFeatureEmbedder,
|
||||
PositionPairDistEmbedder,
|
||||
RelativePositionEncodingWithIndexRemoval,
|
||||
SinusoidalDistEmbed,
|
||||
@@ -49,6 +50,7 @@ class TokenInitializer(nn.Module):
|
||||
pairformer_block,
|
||||
downcast,
|
||||
token_1d_features,
|
||||
token_2d_features,
|
||||
atom_1d_features,
|
||||
atom_transformer,
|
||||
use_chunked_pll=False, # New parameter for memory optimization
|
||||
@@ -62,6 +64,7 @@ class TokenInitializer(nn.Module):
|
||||
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
|
||||
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
|
||||
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
|
||||
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
|
||||
|
||||
self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
|
||||
self.transition_post_token = Transition(c=c_s, n=2)
|
||||
@@ -202,6 +205,8 @@ class TokenInitializer(nn.Module):
|
||||
Z_init_II = Z_init_II + self.ref_pos_embedder_tok(
|
||||
f["ref_pos"][f["is_ca"]], valid_mask
|
||||
)
|
||||
# Add extra token pair features
|
||||
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
|
||||
|
||||
# Run a small transformer to provide position encodings to single.
|
||||
for block in self.transformer_stack:
|
||||
|
||||
@@ -45,6 +45,7 @@ from rfd3.transforms.util_transforms import (
|
||||
get_af3_token_representative_masks,
|
||||
)
|
||||
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
|
||||
from rfd3.transforms.na_geom import get_bp_feats_from_atom_array
|
||||
|
||||
from foundry.utils.ddp import RankedLogger # noqa
|
||||
|
||||
@@ -776,6 +777,90 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
data = self.generate_feature(feature_name, n_dims, data, "atom")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
class AddAdditional2dFeaturesToFeats(Transform):
|
||||
"""
|
||||
Adds any net.token_initializer.token_2d_features and net.diffusion_module.diffusion_atom_encoder.atom_2d_features present in the atomarray but not in data['feats'] to data['feats']
|
||||
Args:
|
||||
- autofill_zeros_if_not_present_in_atomarray: self explanatory
|
||||
- token_2d_features: List of single-item dictionaries, corresponding to feature_name: n_feature_dims. Should be hydra interpolated from
|
||||
net.token_initializer.token_2d_features
|
||||
"""
|
||||
|
||||
incompatible_previous_transforms = ["AddAdditional2dFeaturesToFeats"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_2d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme="atom14",
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_2d_features = token_2d_features
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Need to pre-define custom constructor functions
|
||||
# to map from atomarray annotations to tensors.
|
||||
self.constructor_functions = {
|
||||
'bp_partners': get_bp_feats_from_atom_array,
|
||||
}
|
||||
|
||||
def check_input(self, data) -> None:
|
||||
check_contains_keys(data, ["atom_array"])
|
||||
check_is_instance(data, "atom_array", AtomArray)
|
||||
|
||||
def generate_token_feature(self, feature_name, n_dims, data):
|
||||
|
||||
# Don't do this if we already have the feature
|
||||
if feature_name in data["feats"].keys():
|
||||
return data
|
||||
|
||||
# For these, we need to use a constructor function mapping,
|
||||
# since pair features may require custom logic/conventions.
|
||||
if feature_name in self.constructor_functions.keys():
|
||||
feature_array = self.constructor_functions[feature_name](data["atom_array"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No constructor function found for 2d feature `{feature_name}`"
|
||||
)
|
||||
|
||||
# We can fix shape issues here:
|
||||
if len(feature_array.shape) == 2 and n_dims == 1:
|
||||
feature_array = feature_array.unsqueeze(1)
|
||||
|
||||
# ensure that feature_array is a 3d array with third dim == n_dims:
|
||||
if len(feature_array.shape) != 3:
|
||||
raise ValueError(
|
||||
f"token 2d_feature `{feature_name}` must be a 3d array, got {len(feature_array.shape)}d."
|
||||
)
|
||||
if feature_array.shape[2] != n_dims:
|
||||
raise ValueError(
|
||||
f"token 2d_feature `{feature_name}` dimensions in atomarray ({feature_array.shape[-1]}) does not match dimension declared in config, ({n_dims})"
|
||||
)
|
||||
# Ensure correct shape in first two dims (I,I,...)
|
||||
if feature_array.shape[0] != feature_array.shape[1]:
|
||||
raise ValueError(
|
||||
f"token 2d_feature `{feature_name}` first two dimensions must be equal (square matrix), got {feature_array.shape[0]} and {feature_array.shape[1]}"
|
||||
)
|
||||
|
||||
data["feats"][feature_name] = feature_array
|
||||
|
||||
return data
|
||||
|
||||
def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Checks if the 2d_features are present in data['feats']. If not present, adds them from the atomarray.
|
||||
If annotation is not present in atomarray, either autofills the feature with 0s or throws an error
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
# Only apply for features that the model is expecting:
|
||||
for feature_name, n_dims in self.token_2d_features.items():
|
||||
data = self.generate_token_feature(feature_name, n_dims, data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class FeaturizepLDDT(Transform):
|
||||
|
||||
285
models/rfd3/src/rfd3/transforms/na_geom.py
Normal file
285
models/rfd3/src/rfd3/transforms/na_geom.py
Normal file
@@ -0,0 +1,285 @@
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from biotite.structure import AtomArray
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_atom_array_annotation,
|
||||
check_contains_keys,
|
||||
check_is_instance,
|
||||
)
|
||||
from atomworks.ml.transforms.base import Transform
|
||||
from rfd3.transforms.conditioning_utils import sample_island_tokens
|
||||
from rfd3.transforms.na_geom_utils import (
|
||||
annotate_na_ss,
|
||||
annotate_na_ss_from_data_specification,
|
||||
bp_partner_to_ss_matrix,
|
||||
)
|
||||
|
||||
from atomworks.ml.utils.token import spread_token_wise, get_token_starts
|
||||
|
||||
def get_bp_feats_from_atom_array(
|
||||
atom_array: AtomArray,
|
||||
) -> np.ndarray:
|
||||
"""Build NA-SS features from atom_array annotations, assuming 'bp_partners' is present.
|
||||
|
||||
This function reconstructs the SS matrix from the 'bp_partners' annotation on the atom_array,
|
||||
then one-hot encodes it into a 3-class matrix (mask, pair, loop).
|
||||
"""
|
||||
# Fixed feature info (inferred from usage in other functions)
|
||||
feature_info = {
|
||||
'NA_SS_MASK': 0, # Unspecified
|
||||
'NA_SS_PAIR': 1, # Paired
|
||||
'NA_SS_LOOP': 2, # Loop / unpaired
|
||||
'num_classes_nucleic_ss': 3,
|
||||
}
|
||||
|
||||
# Check for required annotation
|
||||
if "bp_partners" not in atom_array.get_annotation_categories():
|
||||
raise ValueError("atom_array must have 'bp_partners' annotation for NA-SS feature building.")
|
||||
|
||||
# Reconstruct SS matrix from annotations
|
||||
na_ss_matrix = np.asarray(
|
||||
bp_partner_to_ss_matrix(
|
||||
atom_array,
|
||||
feature_info=feature_info,
|
||||
NA_only=False, # Include all residues (logic from other utils)
|
||||
planar_only=True, # Use planar interactions (common default)
|
||||
include_loops=True, # Include loop states
|
||||
),
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
# One-hot encode the matrix
|
||||
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
|
||||
eye = np.eye(int(feature_info['num_classes_nucleic_ss']), dtype=np.int64)
|
||||
return eye[na_ss_matrix_int]
|
||||
|
||||
|
||||
def _build_na_ss_features_from_annotations(
|
||||
atom_array: AtomArray,
|
||||
*,
|
||||
feature_info: dict,
|
||||
num_classes: int,
|
||||
NA_only: bool,
|
||||
planar_only: bool,
|
||||
is_nucleic_ss_example: bool,
|
||||
give_partial_feats: bool,
|
||||
get_feature_mask_fn,
|
||||
) -> np.ndarray:
|
||||
"""Reconstruct SS matrix from annotations, optionally mask, then one-hot."""
|
||||
na_ss_matrix = np.asarray(
|
||||
bp_partner_to_ss_matrix(
|
||||
atom_array,
|
||||
feature_info=feature_info,
|
||||
NA_only=NA_only,
|
||||
planar_only=planar_only,
|
||||
include_loops=True,
|
||||
),
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
n_tokens = int(na_ss_matrix.shape[0])
|
||||
|
||||
if give_partial_feats:
|
||||
is_shown = (
|
||||
np.asarray(get_feature_mask_fn(n_tokens), dtype=bool)
|
||||
if is_nucleic_ss_example
|
||||
else np.zeros((n_tokens,), dtype=bool)
|
||||
)
|
||||
na_ss_matrix[~is_shown, :] = feature_info["NA_SS_MASK"]
|
||||
na_ss_matrix[:, ~is_shown] = feature_info["NA_SS_MASK"]
|
||||
|
||||
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
|
||||
eye = np.eye(int(num_classes), dtype=np.int64)
|
||||
return eye[na_ss_matrix_int]
|
||||
|
||||
|
||||
class CalculateNucleicAcidGeomFeats(Transform):
|
||||
"""
|
||||
Transform for constructing nucleic-acid conditioning features.
|
||||
|
||||
This transform currently produces only nucleic-acid secondary-structure (NA-SS)
|
||||
features as a 2D token-token matrix with 3 bins:
|
||||
* 0: mask / unspecified
|
||||
* 1: paired
|
||||
* 2: loop / explicitly unpaired
|
||||
|
||||
Training:
|
||||
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
|
||||
via the ``bp_partner`` annotation (annotation-first), then reconstructs the
|
||||
matrix (and optionally masks parts of it) before one-hot encoding.
|
||||
|
||||
Inference:
|
||||
- Interprets user-provided secondary-structure specifications, writes the same
|
||||
``bp_partner`` annotation, then follows the same matrix + one-hot path.
|
||||
|
||||
Note: helical-parameter features are not implemented/used in this refactored path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_inference,
|
||||
add_nucleic_ss_feats: bool = True,
|
||||
|
||||
p_is_nucleic_ss_example: float = 0.3,
|
||||
p_show_partial_feats: float = 0.5,
|
||||
nucleic_ss_min_shown: float = 0.0,
|
||||
nucleic_ss_max_shown: float = 1.0,
|
||||
n_islands_min: int = 1,
|
||||
n_islands_max: int = 6,
|
||||
p_canonical_bp_filter: float = 0.0,
|
||||
|
||||
# USE_RF2AA_NAMES: bool = False,
|
||||
NA_only: bool = False,
|
||||
planar_only : bool = True,
|
||||
|
||||
):
|
||||
# Critical, must always have to know how to handle
|
||||
self.is_inference = is_inference
|
||||
|
||||
# For sampling whether we add nucleic-ss features (extra t2d)
|
||||
self.add_nucleic_ss_feats = add_nucleic_ss_feats
|
||||
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
|
||||
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
|
||||
self.nucleic_ss_min_shown = nucleic_ss_min_shown
|
||||
self.nucleic_ss_max_shown = nucleic_ss_max_shown
|
||||
self.n_islands_min = n_islands_min
|
||||
self.n_islands_max = n_islands_max
|
||||
|
||||
self.p_show_partial_feats = p_show_partial_feats
|
||||
|
||||
# Filters for what can be considered a planar contact interaction
|
||||
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
|
||||
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
|
||||
self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter
|
||||
|
||||
# Inds of annotation types in the nucleic-ss features (stack of 3 matrices):
|
||||
self.feature_info = {
|
||||
'NA_SS_MASK' : 0, # Unspecified, or sm, or protein:
|
||||
'NA_SS_PAIR' : 1,
|
||||
'NA_SS_LOOP' : 2,
|
||||
'num_classes_nucleic_ss' : 3,
|
||||
}
|
||||
|
||||
|
||||
def check_input(self, data: dict[str, Any]) -> None:
|
||||
check_contains_keys(data, ["atom_array"])
|
||||
check_is_instance(data, "atom_array", AtomArray)
|
||||
check_atom_array_annotation(data, ["res_name"])
|
||||
# maybe do later: check_atom_array_has_hydrogen(data)
|
||||
|
||||
def _sample_training_flags(self) -> tuple[bool, bool]:
|
||||
"""Sample booleans controlling whether/how features are shown in training."""
|
||||
is_nucleic_ss_example = bool(
|
||||
self.add_nucleic_ss_feats
|
||||
and (np.random.rand() < self.p_is_nucleic_ss_example)
|
||||
)
|
||||
give_partial_feats = bool(
|
||||
np.random.rand() < self.p_show_partial_feats
|
||||
)
|
||||
return is_nucleic_ss_example, give_partial_feats
|
||||
|
||||
def forward(self, data: dict) -> dict:
|
||||
atom_array = data["atom_array"]
|
||||
|
||||
# Calculate n_tokens (assuming one token per residue for simplicity)
|
||||
token_starts = get_token_starts(atom_array)
|
||||
token_level_array = atom_array[token_starts]
|
||||
token_ids = [int(t) for t in token_level_array.token_id]
|
||||
n_tokens = len(token_starts)
|
||||
print(" DO I NEED TO CHANGE TO TOKEN_ID???")
|
||||
# Handle the training case with ground truth and masking:
|
||||
if not self.is_inference:
|
||||
|
||||
# First, annotate as usual
|
||||
# atom_array = annotate_na_ss(atom_array, **kwargs)
|
||||
atom_array = annotate_na_ss(atom_array,
|
||||
NA_only=self.NA_only,
|
||||
planar_only=self.planar_only,
|
||||
p_canonical_bp_filter=self.p_canonical_bp_filter,
|
||||
)
|
||||
|
||||
# Sample mask on token level:
|
||||
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
|
||||
is_ss_shown = self._sample_where_to_show_ss(n_tokens,
|
||||
is_nucleic_ss_example=is_nucleic_ss_example,
|
||||
give_partial_feats=give_partial_feats) # Mask vec for tokens where ss shown
|
||||
# Spread mask to atom level
|
||||
is_ss_shown = spread_token_wise(atom_array, is_ss_shown)
|
||||
|
||||
|
||||
# Extract the base pair annotations
|
||||
bp_partners_atom = atom_array.get_annotation("bp_partners")
|
||||
|
||||
# Remove unshown positions from bp_partners annotation
|
||||
bp_partners_atom[~is_ss_shown] = None
|
||||
|
||||
# Reset the annotation with newly hidden positions
|
||||
atom_array.set_annotation("bp_partners", bp_partners_atom)
|
||||
|
||||
# Inference case: create from commandline args
|
||||
else:
|
||||
"""
|
||||
Different cases handled:
|
||||
- 1). Single dot-bracket string
|
||||
- 2). multiple dot bracket strings with chain/ind ranges specified
|
||||
- 3). Lists of paired indices
|
||||
|
||||
"""
|
||||
is_nucleic_ss_example=True
|
||||
give_partial_feats=False
|
||||
atom_array = annotate_na_ss_from_data_specification(
|
||||
data,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Check feats existence and update:
|
||||
if "feats" not in data:
|
||||
data["feats"] = {}
|
||||
|
||||
# data["feats"].update(nucleic_features)
|
||||
data.setdefault("log_dict", {})
|
||||
log_dict = data["log_dict"]
|
||||
data["log_dict"] = log_dict
|
||||
data["atom_array"] = atom_array
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _sample_where_to_show_ss(self, n_tokens: int,
|
||||
is_nucleic_ss_example: bool = True,
|
||||
give_partial_feats: bool = True,
|
||||
) -> np.ndarray:
|
||||
"""Sample token-level islands indicating which SS rows/cols to reveal."""
|
||||
# If NOT is_nucleic_ss_example, set is_shown to all False
|
||||
if not is_nucleic_ss_example:
|
||||
return np.zeros((n_tokens,), dtype=bool)
|
||||
|
||||
# If NOT give_partial_feats, set is_shown to all True
|
||||
if not give_partial_feats:
|
||||
return np.ones((n_tokens,), dtype=bool)
|
||||
else:
|
||||
frac_shown = (
|
||||
self.nucleic_ss_min_shown
|
||||
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand()
|
||||
)
|
||||
frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
|
||||
max_length = int(np.ceil(frac_shown * n_tokens))
|
||||
if max_length <= 0:
|
||||
return np.zeros((n_tokens,), dtype=bool)
|
||||
|
||||
island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)))
|
||||
island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)))
|
||||
island_len_min = min(island_len_min, n_tokens)
|
||||
island_len_max = min(island_len_max, n_tokens)
|
||||
island_len_max = max(island_len_max, island_len_min)
|
||||
|
||||
return sample_island_tokens(
|
||||
n_tokens,
|
||||
island_len_min=island_len_min,
|
||||
island_len_max=island_len_max,
|
||||
n_islands_min=self.n_islands_min,
|
||||
n_islands_max=self.n_islands_max,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
1708
models/rfd3/src/rfd3/transforms/na_geom_utils.py
Normal file
1708
models/rfd3/src/rfd3/transforms/na_geom_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -72,6 +72,7 @@ from rfd3.transforms.conditioning_base import (
|
||||
)
|
||||
from rfd3.transforms.design_transforms import (
|
||||
AddAdditional1dFeaturesToFeats,
|
||||
AddAdditional2dFeaturesToFeats,
|
||||
AddGroundTruthSequence,
|
||||
AddIsXFeats,
|
||||
AssignTypes,
|
||||
@@ -84,6 +85,7 @@ from rfd3.transforms.design_transforms import (
|
||||
)
|
||||
from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
|
||||
from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
|
||||
from rfd3.transforms.na_geom import CalculateNucleicAcidGeomFeats
|
||||
from rfd3.transforms.ppi_transforms import (
|
||||
Add1DSSFeature,
|
||||
AddGlobalIsNonLoopyFeature,
|
||||
@@ -350,6 +352,7 @@ def build_atom14_base_pipeline_(
|
||||
center_option: str,
|
||||
atom_1d_features: dict | None,
|
||||
token_1d_features: dict | None,
|
||||
token_2d_features: dict | None,
|
||||
# PPI features
|
||||
max_ppi_hotspots_frac_to_provide: float,
|
||||
ppi_hotspot_max_distance: float,
|
||||
@@ -357,6 +360,8 @@ def build_atom14_base_pipeline_(
|
||||
max_ss_frac_to_provide: float,
|
||||
min_ss_island_len: int,
|
||||
max_ss_island_len: int,
|
||||
# Nucleic acid features
|
||||
add_na_pair_features: bool,
|
||||
**_, # dump additional kwargs (e.g. msa stuff)
|
||||
):
|
||||
"""
|
||||
@@ -437,6 +442,15 @@ def build_atom14_base_pipeline_(
|
||||
),
|
||||
)
|
||||
)
|
||||
# Add nucleic acid geometry features
|
||||
if add_na_pair_features:
|
||||
transforms.append(
|
||||
CalculateNucleicAcidGeomFeats(
|
||||
is_inference,
|
||||
NA_only=False,
|
||||
planar_only=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Design Transforms
|
||||
transforms += [
|
||||
@@ -525,6 +539,11 @@ def build_atom14_base_pipeline_(
|
||||
atom_1d_features=atom_1d_features,
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
AddAdditional2dFeaturesToFeats(
|
||||
autofill_zeros_if_not_present_in_atomarray=True,
|
||||
token_2d_features=token_2d_features,
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
AddAF3TokenBondFeatures(),
|
||||
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
|
||||
ConditionalRoute(
|
||||
@@ -615,6 +634,7 @@ def build_atom14_base_pipeline(
|
||||
kwargs.setdefault("min_ss_island_len", 0)
|
||||
kwargs.setdefault("max_ss_island_len", 999)
|
||||
kwargs.setdefault("max_binder_length", 999)
|
||||
kwargs.setdefault("add_na_pair_features", False)
|
||||
|
||||
kwargs.setdefault("b_factor_min", None)
|
||||
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
|
||||
|
||||
Reference in New Issue
Block a user