diff --git a/cookbook/snippets/sae.py b/cookbook/snippets/sae.py new file mode 100644 index 0000000..440c1bc --- /dev/null +++ b/cookbook/snippets/sae.py @@ -0,0 +1,70 @@ +import numpy as np +import torch + +from esm.sdk import batch_executor +from esm.sdk.api import ( + ESMProtein, + ESMProteinError, + LogitsConfig, + SAEConfig, +) +from esm.sdk.forge import ESMCForgeInferenceClient + +from cookbook.snippets.sparse_utils import max_pool, remove_indexes + + +def get_sae_features_single( + client: ESMCForgeInferenceClient, + sae_config: SAEConfig, + sequence: str, + pool: bool = True, +) -> torch.tensor: + protein = ESMProtein(sequence=sequence) + protein_tensor = client.encode(protein) + if isinstance(protein_tensor, ESMProteinError): + raise ValueError( + f"Error encoding sequence {sequence}: {protein_tensor.error_msg}" + ) + + # We wrap the SAEConfig in the LogitsConfig, which is normally used to return embeddings and hidden states. + output = client.logits( + protein_tensor, + config=LogitsConfig(sae_config=sae_config), + return_bytes=False, + ) + if isinstance(output, ESMProteinError): + raise ValueError( + f"Error getting SAE features for sequence {sequence}: {output.error_msg}" + ) + if output.sae_outputs is None: + raise ValueError(f"SAE outputs missing for sequence {sequence}: {output}") + sae_tensor = output.sae_outputs[sae_config.model] + + if pool: + # Remove BOS / EOS tokens before pooling. + sae_features = remove_indexes(sae_tensor, {0, -1}) + pooled_sae_features = max_pool(sae_features, axis=0) + return pooled_sae_features + else: + return sae_tensor + + +def get_sae_features( + client: ESMCForgeInferenceClient, + sae_config: SAEConfig, + sequences: list[str], + pool: bool = True, +) -> list[np.ndarray]: + with batch_executor() as executor: + results = executor.execute_batch( + user_func=get_sae_features_single, + client=client, + sae_config=sae_config, + sequence=sequences, + pool=pool, + ) + # Re-raise any errors from the batch + for result in results: + if isinstance(result, Exception): + raise result + return results diff --git a/cookbook/snippets/sae_example.py b/cookbook/snippets/sae_example.py new file mode 100644 index 0000000..1fdcbe8 --- /dev/null +++ b/cookbook/snippets/sae_example.py @@ -0,0 +1,41 @@ +import os + +from esm.sdk.api import SAEConfig +from esm.sdk.forge import ESMCForgeInferenceClient + +from cookbook.snippets.sae import get_sae_features, get_sae_features_single +from cookbook.snippets.sparse_utils import remove_indexes + +# Create ESMC 600M client +client = ESMCForgeInferenceClient( + model="esmc-600m-2024-12", + url="https://forge.evolutionaryscale.ai", + token=os.environ["ESM_API_KEY"], +) + +# normalize feature activations by TF-IDF. Upweights activations +# of more highly specific features +sae_config = SAEConfig( + model="esmc-600m-2024-12_k64_codebook16384_layer27", + normalize_features=True, +) + +# Create a protein +sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQATHVDQWDWEWAGIKATEAFLPDYPDLDA" +sequences = [sequence] * 10 + +# get unpooled features for a single sequence +unpooled_features = get_sae_features_single(client, sae_config, sequence, pool=False) +print(f"Got unpooled SAE features with shape {unpooled_features.shape}") +print(f"is_sparse: {unpooled_features.is_sparse}") +print(f"layout: {unpooled_features.layout}") + +# To remove bos/eos tokens efficiently from sparse tensors, we use a utility +unpooled_features = remove_indexes(unpooled_features, {0, -1}) +print(f"Unpooled SAE features after removing BOS/EOS have shape {unpooled_features.shape}") + +# get pooled features for a batch +# this function pools by default to save memory. +features = get_sae_features(client, sae_config, sequences) +print(f"Got SAE features for {len(features)} sequences, each with shape {features[0].shape}") + diff --git a/cookbook/snippets/sparse_utils.py b/cookbook/snippets/sparse_utils.py new file mode 100644 index 0000000..594b922 --- /dev/null +++ b/cookbook/snippets/sparse_utils.py @@ -0,0 +1,123 @@ + +from typing import Iterable + +import torch + +def remove_indexes( + sparse_coo_tensor: torch.Tensor, indexes_to_remove: Iterable[int] +) -> torch.Tensor: + """Remove entries at specified position indexes from sparse features. + + This function removes positions and remaps the remaining indices to be contiguous. + For example, if we remove position 1 from a tensor with positions [0, 1, 2, 3], + the result will have positions [0, 1, 2] (where old position 2 becomes new position 1). + + For example, remove_indexes(x, [0, -1]) will return the equivalent of tensor.to_dense().numpy()[1:-1] + + Args: + sparse_coo_tensor: A sparse COO tensor of shape (num_positions, num_features) + indexes_to_remove: Iterable of position indexes to remove (supports negative indexing) + + Returns: + A new sparse COO tensor with the specified positions removed and indices remapped + """ + if not sparse_coo_tensor.is_sparse or sparse_coo_tensor.layout != torch.sparse_coo: + raise TypeError("sparse_coo_tensor must be a torch.sparse_coo_tensor.") + + if sparse_coo_tensor.dim() != 2: + raise ValueError( + f"sparse tensors with more than 2 dimensions are not supported, got {sparse_coo_tensor.dim()} dimensions" + ) + + indices = sparse_coo_tensor.indices() + values = sparse_coo_tensor.values() + num_positions = sparse_coo_tensor.size(0) + num_features = sparse_coo_tensor.size(1) + + # Convert negative indices to positive and create sorted list + indexes_to_remove_list = [] + for idx in indexes_to_remove: + if idx < 0: + idx = num_positions + idx + indexes_to_remove_list.append(idx) + indexes_to_remove_set = set(indexes_to_remove_list) + + if max(indexes_to_remove_set) > num_positions - 1: + raise ValueError( + f"Index to remove {max(indexes_to_remove_set)} is out of bounds for tensor with size {num_positions}" + ) + + position_indices = indices[0] + mask = ~torch.isin( + position_indices, + torch.tensor(list(indexes_to_remove_set), device=position_indices.device), + ) + filtered_indices = indices[:, mask] + new_values = values[mask] + + # Create mapping from old positions to new positions + # new position = old position - count(removed positions < old position) + old_positions = filtered_indices[0] + sorted_removed = sorted(indexes_to_remove_set) + position_mapping = torch.zeros( + num_positions, dtype=torch.long, device=position_indices.device + ) + removed_count = 0 + removed_idx = 0 + + for pos in range(num_positions): + while removed_idx < len(sorted_removed) and sorted_removed[removed_idx] < pos: + removed_count += 1 + removed_idx += 1 + position_mapping[pos] = pos - removed_count + + # Apply mapping to position indices + new_position_indices = position_mapping[old_positions] + + # Construct new indices with remapped positions + new_indices = torch.stack([new_position_indices, filtered_indices[1]], dim=0) + + new_num_positions = num_positions - len(indexes_to_remove_set) + return torch.sparse_coo_tensor( + new_indices, new_values, size=(new_num_positions, num_features) + ).coalesce() + + +def max_pool(sparse_coo_tensor: torch.Tensor, axis: int) -> torch.Tensor: + """Max pool sparse features along the specified axis. + + Args: + sparse_coo_tensor: A sparse COO tensor of shape (num_positions, num_features) + axis: The axis to pool over (0 for positions, 1 for features) + + Returns: + Max-pooled tensor. + """ + if not sparse_coo_tensor.is_sparse or sparse_coo_tensor.layout != torch.sparse_coo: + raise TypeError("sparse_coo_tensor must be a torch.sparse_coo_tensor.") + + if sparse_coo_tensor.dim() != 2: + raise ValueError( + f"sparse tensors with more than 2 dimensions are not supported, got {sparse_coo_tensor.dim()} dimensions" + ) + + if axis not in (0, 1): + raise ValueError(f"axis must be 0 or 1, got {axis}") + + indices = sparse_coo_tensor.indices() + values = sparse_coo_tensor.values() + + if axis == 0: + # Pool over positions (axis 0), return max per feature + output_size = sparse_coo_tensor.size(1) + scatter_indices = indices[1] # feature indices + else: # axis == 1 + # Pool over features (axis 1), return max per position + output_size = sparse_coo_tensor.size(0) + scatter_indices = indices[0] # position indices + + result = torch.zeros(output_size, dtype=values.dtype, device=values.device) + result.scatter_reduce_( + 0, scatter_indices, values, reduce="amax", include_self=False + ) + return result \ No newline at end of file