add examples of getting sae features

This commit is contained in:
Neil Thomas
2026-03-17 18:16:52 -07:00
parent 6596fc1f30
commit ddffe05768
3 changed files with 234 additions and 0 deletions

70
cookbook/snippets/sae.py Normal file
View File

@@ -0,0 +1,70 @@
import numpy as np
import torch
from esm.sdk import batch_executor
from esm.sdk.api import (
ESMProtein,
ESMProteinError,
LogitsConfig,
SAEConfig,
)
from esm.sdk.forge import ESMCForgeInferenceClient
from cookbook.snippets.sparse_utils import max_pool, remove_indexes
def get_sae_features_single(
    client: ESMCForgeInferenceClient,
    sae_config: SAEConfig,
    sequence: str,
    pool: bool = True,
) -> torch.Tensor:
    """Fetch SAE feature activations for a single protein sequence.

    The sequence is encoded with ``client.encode`` and the SAE features are
    requested through ``client.logits`` by wrapping ``sae_config`` in a
    ``LogitsConfig``.

    Args:
        client: Forge inference client used for encoding and for the logits call.
        sae_config: SAE configuration. ``sae_config.model`` is also the key used
            to look up the result in ``output.sae_outputs``.
        sequence: Protein sequence to featurize.
        pool: When True (default), drop the first and last positions (BOS / EOS
            tokens) and max-pool over the position axis. When False, return the
            SAE tensor exactly as received.

    Returns:
        The pooled feature tensor when ``pool`` is True, otherwise the unpooled
        SAE tensor (expected to be a sparse COO tensor, which is what
        ``remove_indexes`` / ``max_pool`` require).

    Raises:
        ValueError: If encoding fails, the logits call fails, or the response
            carries no SAE outputs.
    """
    protein = ESMProtein(sequence=sequence)
    protein_tensor = client.encode(protein)
    # The client reports failures by returning an ESMProteinError instead of raising.
    if isinstance(protein_tensor, ESMProteinError):
        raise ValueError(
            f"Error encoding sequence {sequence}: {protein_tensor.error_msg}"
        )

    # We wrap the SAEConfig in the LogitsConfig, which is normally used to
    # return embeddings and hidden states.
    output = client.logits(
        protein_tensor,
        config=LogitsConfig(sae_config=sae_config),
        return_bytes=False,
    )
    if isinstance(output, ESMProteinError):
        raise ValueError(
            f"Error getting SAE features for sequence {sequence}: {output.error_msg}"
        )
    if output.sae_outputs is None:
        raise ValueError(f"SAE outputs missing for sequence {sequence}: {output}")

    sae_tensor = output.sae_outputs[sae_config.model]
    if not pool:
        return sae_tensor

    # Remove BOS / EOS tokens before pooling so they do not contribute to the max.
    sae_features = remove_indexes(sae_tensor, {0, -1})
    return max_pool(sae_features, axis=0)
def get_sae_features(
    client: ESMCForgeInferenceClient,
    sae_config: SAEConfig,
    sequences: list[str],
    pool: bool = True,
) -> list[torch.Tensor]:
    """Fetch SAE features for many sequences via the SDK batch executor.

    Each sequence is processed by ``get_sae_features_single``; ``client``,
    ``sae_config`` and ``pool`` are passed through unchanged to every call.

    Args:
        client: Forge inference client shared by all requests.
        sae_config: SAE configuration shared by all requests.
        sequences: Protein sequences to featurize.
        pool: Forwarded to ``get_sae_features_single``; pooled by default.

    Returns:
        The batch results as returned by ``execute_batch`` — one tensor per
        input sequence.

    Raises:
        Exception: The first exception found among the batch results is
            re-raised as-is.
    """
    with batch_executor() as executor:
        results = executor.execute_batch(
            user_func=get_sae_features_single,
            client=client,
            sae_config=sae_config,
            sequence=sequences,
            pool=pool,
        )

    # The executor hands failures back as exception objects inside the result
    # list; surface the first one instead of returning a mixed list.
    for result in results:
        if isinstance(result, Exception):
            raise result
    return results

View File

@@ -0,0 +1,41 @@
import os
from esm.sdk.api import SAEConfig
from esm.sdk.forge import ESMCForgeInferenceClient
from cookbook.snippets.sae import get_sae_features, get_sae_features_single
from cookbook.snippets.sparse_utils import remove_indexes
# Build a Forge client for the ESMC 600M model. The API token is read from
# the ESM_API_KEY environment variable (a KeyError is raised if it is unset).
api_token = os.environ["ESM_API_KEY"]
client = ESMCForgeInferenceClient(
    model="esmc-600m-2024-12",
    url="https://forge.evolutionaryscale.ai",
    token=api_token,
)

# Ask for TF-IDF-normalized feature activations, which upweights the
# activations of more highly specific features.
sae_config = SAEConfig(
    model="esmc-600m-2024-12_k64_codebook16384_layer27",
    normalize_features=True,
)

# One example protein, plus a small batch made of ten copies of it.
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQATHVDQWDWEWAGIKATEAFLPDYPDLDA"
sequences = [sequence for _ in range(10)]

# Single sequence, no pooling: one row of features per token position.
unpooled_features = get_sae_features_single(
    client,
    sae_config,
    sequence,
    pool=False,
)
print(f"Got unpooled SAE features with shape {unpooled_features.shape}")
print(f"is_sparse: {unpooled_features.is_sparse}")
print(f"layout: {unpooled_features.layout}")

# The BOS/EOS rows are stripped with a sparse-aware helper so the tensor
# never has to be densified.
unpooled_features = remove_indexes(unpooled_features, {0, -1})
print(
    f"Unpooled SAE features after removing BOS/EOS have shape {unpooled_features.shape}"
)

# Batch call: features are pooled by default to save memory.
features = get_sae_features(client, sae_config, sequences)
print(
    f"Got SAE features for {len(features)} sequences, each with shape {features[0].shape}"
)

View File

@@ -0,0 +1,123 @@
from typing import Iterable
import torch
def remove_indexes(
    sparse_coo_tensor: torch.Tensor, indexes_to_remove: Iterable[int]
) -> torch.Tensor:
    """Remove entries at specified position indexes from sparse features.

    This function removes positions and remaps the remaining indices to be
    contiguous. For example, if we remove position 1 from a tensor with
    positions [0, 1, 2, 3], the result will have positions [0, 1, 2] (where old
    position 2 becomes new position 1).

    For example, remove_indexes(x, [0, -1]) will return the equivalent of
    tensor.to_dense().numpy()[1:-1].

    Args:
        sparse_coo_tensor: A sparse COO tensor of shape (num_positions, num_features).
            It does not need to be coalesced; it is coalesced internally.
        indexes_to_remove: Iterable of position indexes to remove (supports
            negative indexing). Duplicates are ignored; an empty iterable
            removes nothing.

    Returns:
        A new, coalesced sparse COO tensor with the specified positions removed
        and the remaining position indices remapped to be contiguous.

    Raises:
        TypeError: If the input is not a sparse COO tensor.
        ValueError: If the input is not 2-dimensional, or if any index falls
            outside ``[-num_positions, num_positions - 1]``.
    """
    if not sparse_coo_tensor.is_sparse or sparse_coo_tensor.layout != torch.sparse_coo:
        raise TypeError("sparse_coo_tensor must be a torch.sparse_coo_tensor.")
    if sparse_coo_tensor.dim() != 2:
        raise ValueError(
            f"sparse tensors with more than 2 dimensions are not supported, got {sparse_coo_tensor.dim()} dimensions"
        )

    num_positions = sparse_coo_tensor.size(0)
    num_features = sparse_coo_tensor.size(1)

    # Normalize negative indexes and validate BOTH bounds. A too-negative index
    # would otherwise stay negative and silently corrupt the remapping.
    removed_positions = set()
    for idx in indexes_to_remove:
        position = idx + num_positions if idx < 0 else idx
        if not 0 <= position < num_positions:
            raise ValueError(
                f"Index to remove {idx} is out of bounds for tensor with size {num_positions}"
            )
        removed_positions.add(position)

    # indices() / values() are only defined on coalesced tensors.
    coalesced = sparse_coo_tensor.coalesce()
    indices = coalesced.indices()
    values = coalesced.values()
    device = indices.device

    # Boolean mask over positions: True for rows that survive.
    keep_position = torch.ones(num_positions, dtype=torch.bool, device=device)
    keep_position[
        torch.tensor(sorted(removed_positions), dtype=torch.long, device=device)
    ] = False

    # new position = old position - count(removed positions < old position),
    # which for surviving rows equals (number of survivors up to and including
    # this row) - 1. Entries for removed rows are never looked up.
    position_mapping = torch.cumsum(keep_position, dim=0, dtype=torch.long) - 1

    # Drop the non-zero entries that live on removed rows, then remap the rest.
    entry_mask = keep_position[indices[0]]
    new_position_indices = position_mapping[indices[0][entry_mask]]
    new_indices = torch.stack([new_position_indices, indices[1][entry_mask]], dim=0)
    new_values = values[entry_mask]

    new_num_positions = num_positions - len(removed_positions)
    return torch.sparse_coo_tensor(
        new_indices, new_values, size=(new_num_positions, num_features)
    ).coalesce()
def max_pool(sparse_coo_tensor: torch.Tensor, axis: int) -> torch.Tensor:
    """Max pool sparse features along the specified axis.

    Only explicitly stored values take part in the max; implicit zeros are not
    considered. A slot with no stored values at all is 0 in the output.
    NOTE(review): for a slot whose stored values are all negative the result is
    that negative max, not 0 — presumably fine for non-negative SAE
    activations; confirm if other inputs are expected.

    Args:
        sparse_coo_tensor: A sparse COO tensor of shape (num_positions, num_features).
            It does not need to be coalesced; it is coalesced internally so
            duplicate coordinates are summed before pooling.
        axis: The axis to pool over (0 for positions, 1 for features).

    Returns:
        Max-pooled dense 1-D tensor whose length is the size of the axis that
        was NOT pooled over, with the dtype and device of the stored values.

    Raises:
        TypeError: If the input is not a sparse COO tensor.
        ValueError: If the input is not 2-dimensional or ``axis`` is not 0 or 1.
    """
    if not sparse_coo_tensor.is_sparse or sparse_coo_tensor.layout != torch.sparse_coo:
        raise TypeError("sparse_coo_tensor must be a torch.sparse_coo_tensor.")
    if sparse_coo_tensor.dim() != 2:
        raise ValueError(
            f"sparse tensors with more than 2 dimensions are not supported, got {sparse_coo_tensor.dim()} dimensions"
        )
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 or 1, got {axis}")

    # indices() / values() are only defined on coalesced tensors.
    coalesced = sparse_coo_tensor.coalesce()
    indices = coalesced.indices()
    values = coalesced.values()

    # Pooling over one axis means grouping the stored values by the other one:
    # axis=0 -> max per feature (group by column), axis=1 -> max per position.
    kept_axis = 1 - axis
    result = torch.zeros(
        coalesced.size(kept_axis), dtype=values.dtype, device=values.device
    )
    # include_self=False so the zero initialisation does not take part in the
    # max for slots that do receive values.
    result.scatter_reduce_(
        0, indices[kept_axis], values, reduce="amax", include_self=False
    )
    return result