Add guided generation (#228)

This commit is contained in:
santiag0m
2025-04-01 17:21:33 -07:00
committed by GitHub
parent 3e7acde90a
commit eed38f2688
5 changed files with 648 additions and 13 deletions

View File

@@ -0,0 +1,405 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Guided Generation with ESM3\n",
"\n",
"Guided generation is a powerful tool that allows you to sample outputs out of ESM3 that maximize any kind of score function.\n",
"\n",
"For example, you may want to\n",
"1. Guide generations towards higher quality metrics like pTM\n",
"2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes\n",
"3. Minimize a biophysical energy function\n",
"4. Use experimental screening data to guide designs with a regression model\n",
"\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Uehara, et al 2024](https://arxiv.org/abs/2408.08252).\n",
"\n",
"In this notebook we will walk through a few examples to illustrate how to use guided generation. \n",
"\n",
"1. Guide towards high pTM for improved generation quality\n",
"\n",
"2. Generate a protein with no Cysteine residues\n",
"\n",
"3. Maximize protein globularity by minimizing the radius of gyration\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title **IMPORTANT:** run this cell first before doing 'Runtime →> Run all'\n",
"#@markdown - The latest update to Google Colab broke numpy; this is a temporary patch.\n",
"#@markdown - Note after running this cell, the session will crash (this is normal).\n",
"\n",
"import os, numpy, signal\n",
"\n",
"if numpy.__version__ != '1.26.4':\n",
" print(f\"Current numpy version {numpy.__version__} is incorrect. Installing 1.26.4...\")\n",
" os.system(\"'pip uninstall -y numpy\")\n",
" os.system(\"pip install numpy==1.26.4\")\n",
" # Restart the runtime using os.kill\n",
" os.kill(os. getpid(), signal.SIGKILL)\n",
"else:\n",
" print (\"Numpy version is correct (1.26.4)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/evolutionaryscale/esm.git\n",
"!pip install py3dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import (\n",
" ESM3GuidedDecoding,\n",
" GuidedDecodingScoringFunction,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a scoring function\n",
"\n",
"To get started with the guided generation API the only thing you need is to create a callable class that inherits from `GuidedDecodingScoringFunction`. This class should receive as input an `ESMProtein` object and output a numerical score.\n",
"\n",
"\n",
"For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the Predicted Template Modelling (pTM) score, so we'll use it to create a `PTMScoringFunction`.\n",
"\n",
"Fortunately for us, every time we generate a protein using ESM3 (either locally or on Forge) we also get its pTM, so all our class needs to do when its called is to return the `ptm` attribute of its input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create scoring function (e.g. PTM scoring function)\n",
"class PTMScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" # Minimal example of a scoring function that scores proteins based on their pTM score\n",
" # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score\n",
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
" return float(protein.ptm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize your client\n",
"\n",
"The guided generation is compatible with both local inference using the `ESM3` class and remote inference with the Forge client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# To use the tokenizers and the open model you'll need to login into Hugging Face\n",
"\n",
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Locally with ESM3-open\n",
"model = ESM3.from_pretrained().to(\"cuda\")\n",
"\n",
"## On Forge with larger ESM3 models\n",
"# from getpass import getpass\n",
"\n",
"# from esm.sdk import client\n",
"\n",
"# token = getpass(\"Token from Forge console: \")\n",
"# model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pTM Guided Generation\n",
"\n",
"Once your scoring function is defined and you have initialized your model you can create an `ESM3GuidedDecoding` instance to sample from it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ptm_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=PTMScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"generated_protein = ptm_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
" num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare against baseline with no guidance\n",
"\n",
"First we are going to sample a protein generated without any guidance. This means that, when not providing pTM guidance, we could be sampling proteins that have no clear structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate a protein WITHOUT guidance\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
" input=starting_protein,\n",
" config=GenerationConfig(track=\"sequence\", num_steps=len(starting_protein) // 8),\n",
") # type: ignore\n",
"\n",
"# Fold\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
" input=generated_protein_no_guided,\n",
" config=GenerationConfig(track=\"structure\", num_steps=1),\n",
") # type: ignore"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a 1x2 grid of viewers (1 row, 2 columns)\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
"\n",
"# Convert ESMProtein objects to ProteinChain objects\n",
"protein_chain1 = generated_protein_no_guided.to_protein_chain()\n",
"protein_chain2 = generated_protein.to_protein_chain()\n",
"\n",
"# Add models to respective panels\n",
"view.addModel(protein_chain1.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
"view.addModel(protein_chain2.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
"\n",
"# Set styles for each protein\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 0))\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 1))\n",
"\n",
"# Zoom and center the view\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate a Protein with No Cysteines\n",
"\n",
"Guided generation is not constrained to structural metrics, you can also use it to guide the sequence generation.\n",
"\n",
"For example, we can create a `NoCysteineScoringFunction` that penalizes the protein if it contains Cysteine residues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NoCysteineScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" # Penalize proteins that contain cysteine\n",
" assert protein.sequence is not None, \"Protein must have a sequence to be scored\"\n",
" # Note that we use a negative score here, to discourage the presence of cysteine\n",
" return -protein.sequence.count(\"C\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=NoCysteineScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
" num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check our sequence!\n",
"\n",
"If guided generation converged to `score == 0.00`, the resulting protein should contain no Cysteine residues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert no_cysteine_protein.sequence is not None, \"Protein must have a sequence\"\n",
"print(no_cysteine_protein.sequence)\n",
"print(f\"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maximize Globularity\n",
"\n",
"We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" score = -1 * self.radius_of_gyration(protein)\n",
"\n",
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
" if protein.ptm < 0.5:\n",
" # Penalize proteins with low pTM scores\n",
" score = score * 2\n",
"\n",
" return score\n",
"\n",
" @staticmethod\n",
" def radius_of_gyration(protein: ESMProtein) -> float:\n",
" protein_chain = protein.to_protein_chain()\n",
" arr = protein_chain.atom_array_no_insertions\n",
" return bs.gyration_radius(arr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"radius_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=RadiousOfGyrationScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
" num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"view = py3Dmol.view(width=800, height=400)\n",
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"view.zoomTo()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1 +1,2 @@
__version__ = "3.1.5post1"
__version__ = "3.1.7"

209
esm/sdk/experimental.py Normal file
View File

@@ -0,0 +1,209 @@
from abc import ABC, abstractmethod
import attr
import torch
from tqdm import tqdm
from esm.models.esm3 import ESM3
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.tokenization import get_esm3_model_tokenizers
class GuidedDecodingScoringFunction(ABC):
@abstractmethod
def __call__(self, protein: ESMProtein) -> float:
pass
class ESM3GuidedDecoding:
"""This class can be used to perform derivative-free guided decoding, based on
the method described in "Derivative-Free Guidance in Continuous and Discrete Diffusion Models with Soft Value-Based Decoding"
https://arxiv.org/abs/2408.08252
"""
def __init__(
self,
client: ESM3InferenceClient,
scoring_function: GuidedDecodingScoringFunction,
):
if isinstance(client, ESM3):
self.tokenizers = client.tokenizers
elif isinstance(client, ESM3ForgeInferenceClient):
self.tokenizers = get_esm3_model_tokenizers(client.model)
else:
raise ValueError(
"client must be an instance of ESM3 or ESM3ForgeInferenceClient"
)
self.client = client
self.scoring_function = scoring_function
def guided_generate(
self,
protein: ESMProtein,
num_decoding_steps: int,
num_samples_per_step: int,
denoised_prediction_temperature: float = 0.0,
track: str = "sequence",
verbose: bool = True,
) -> ESMProtein:
protein_tensor = self.client.encode(protein)
assert not isinstance(protein_tensor, ESMProteinError)
if track == "structure":
protein_tensor = self.maybe_add_default_structure_tokens(protein_tensor)
num_masked_positions = self.get_number_of_masked_positions(
protein_tensor, track=track
)
num_positions_to_unmask = num_masked_positions // num_decoding_steps
current_score = -1
if verbose:
pbar = tqdm(range(num_decoding_steps), desc="Current score: -1")
else:
pbar = range(num_decoding_steps)
for step in pbar:
if step == num_decoding_steps - 1:
# At the last step, unmask all remaining positions
num_positions_to_unmask = self.get_number_of_masked_positions(
protein_tensor, track=track
)
samples = []
scores = []
for _ in range(num_samples_per_step):
sample = self.randomly_unmask_positions(
protein_tensor, num_positions_to_unmask, track=track
)
scores.append(
self.reward_function(
sample,
denoised_prediction_temperature=denoised_prediction_temperature,
)
)
samples.append(sample)
# Select best scoring sample
best_sample = samples[scores.index(max(scores))]
current_score = max(scores)
protein_tensor = best_sample
if verbose:
pbar.set_description(f"Current score: {current_score:.2f}") # type: ignore
# Fully predict and decode final protein
protein_tensor_output = self.client.forward_and_sample(
protein_tensor,
SamplingConfig(
sequence=SamplingTrackConfig(temperature=0.0),
structure=SamplingTrackConfig(temperature=0.0),
),
)
assert not isinstance(protein_tensor_output, ESMProteinError)
protein_tensor = protein_tensor_output.protein_tensor
decoded_protein = self.client.decode(protein_tensor)
assert not isinstance(decoded_protein, ESMProteinError)
return decoded_protein
def reward_function(
self,
protein_tensor: ESMProteinTensor,
denoised_prediction_temperature: float = 0.0,
) -> float:
denoised_protein = self.predict_denoised(
protein_tensor, temperature=denoised_prediction_temperature
)
return self.scoring_function(denoised_protein)
def get_number_of_masked_positions(
self, protein_tensor: ESMProteinTensor, track: str = "sequence"
) -> int:
assert isinstance(protein_tensor, ESMProteinTensor)
track_tensor = getattr(protein_tensor, track)
track_tokenizer = getattr(self.tokenizers, track)
is_mask = track_tensor == track_tokenizer.mask_token_id
return is_mask.sum().item() # type: ignore
def randomly_unmask_positions(
self,
protein_tensor: ESMProteinTensor,
num_positions_to_unmask: int,
temperature: float = 1.0,
track: str = "sequence",
) -> ESMProteinTensor:
track_tensor = getattr(protein_tensor, track)
assert track_tensor is not None
protein_tensor = attr.evolve(protein_tensor)
setattr(protein_tensor, track, track_tensor.clone())
track_tensor = getattr(protein_tensor, track)
track_tokenizer = getattr(self.tokenizers, track)
is_mask = track_tensor == track_tokenizer.mask_token_id
num_masked_positions = is_mask.sum().item()
if num_positions_to_unmask > num_masked_positions:
num_positions_to_unmask = num_masked_positions # type: ignore
mask_indices = is_mask.nonzero(as_tuple=False)
mask_indices = mask_indices[torch.randperm(mask_indices.size(0))]
mask_indices = mask_indices[:num_positions_to_unmask]
sampling_config = SamplingConfig()
setattr(sampling_config, track, SamplingTrackConfig(temperature=temperature))
denoised_protein_tensor_output = self.client.forward_and_sample(
protein_tensor, sampling_configuration=sampling_config
)
assert not isinstance(denoised_protein_tensor_output, ESMProteinError)
denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor
output_track_tensor = getattr(denoised_protein_tensor, track)
assert output_track_tensor is not None
track_tensor[mask_indices] = output_track_tensor[mask_indices]
setattr(protein_tensor, track, track_tensor)
return protein_tensor
def predict_denoised(
self, protein_tensor: ESMProteinTensor, temperature: float = 0.0
) -> ESMProtein:
denoised_protein_tensor_output = self.client.forward_and_sample(
protein_tensor,
sampling_configuration=SamplingConfig(
sequence=SamplingTrackConfig(temperature=temperature),
structure=SamplingTrackConfig(temperature=temperature),
),
)
assert not isinstance(denoised_protein_tensor_output, ESMProteinError)
denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor
denoised_protein = self.client.decode(denoised_protein_tensor)
assert not isinstance(denoised_protein, ESMProteinError)
return denoised_protein
def maybe_add_default_structure_tokens(
self, protein_tensor: ESMProteinTensor
) -> ESMProteinTensor:
empty_protein_tensor = ESMProteinTensor.empty(
len(protein_tensor) - 2,
tokenizers=self.tokenizers,
device=protein_tensor.device,
)
if protein_tensor.structure is None:
setattr(protein_tensor, "structure", empty_protein_tensor.structure)
else:
print("Warning: structure already exists in protein_tensor")
return protein_tensor

View File

@@ -134,6 +134,7 @@ def _make_masked_inputs(
track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol
):
get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
has_tokenizer: Callable[[str], bool] = lambda s: hasattr(tokenizers, s)
if track == "coordinates":
dims = (sequence_length, 3, 3)
@@ -155,12 +156,15 @@ def _make_masked_inputs(
masked_tokens = torch.full(dims, 0.0)
elif track == "attention_mask":
masked_tokens = torch.full(dims, 1, dtype=torch.bool)
else:
elif has_tokenizer(track):
masked_tokens = torch.full(
dims, get_tokenizer(track).mask_token_id, dtype=torch.long
)
masked_tokens[0] = get_tokenizer(track).bos_token_id
masked_tokens[-1] = get_tokenizer(track).eos_token_id
else:
# Does not know how to create the dummy all masked input.
return None
return masked_tokens
@@ -173,15 +177,32 @@ def _stack_protein_tensors(
) -> _BatchedESMProteinTensor:
o = _BatchedESMProteinTensor()
def _maybe_mock_input(fn, t, l):
if t is not None:
return t
# Try create dummy masked input for this prompt.
t = _make_masked_inputs(fn, l, tokenizers)
if t is not None:
t = t.to(device)
return t
def _stack_field(fn: str):
tensors = [getattr(tokens, fn) for tokens in input_tokens]
# Create all mask mock inputs for any tensors that are None.
tensors = [
t if t is not None else _make_masked_inputs(fn, l, tokenizers).to(device)
for t, l in zip(tensors, sequence_lengths)
_maybe_mock_input(fn, t, l) for t, l in zip(tensors, sequence_lengths)
]
# Handle any track that has all None as the input.
# We can't meaningfully stack tensors in this case, so simply batched
# them as None in _BatchedESMProteinTensor.
if all([t is None for t in tensors]):
setattr(o, fn, None)
return
if fn == "coordinates":
mask_token_id = torch.inf
else:
@@ -191,7 +212,8 @@ def _stack_protein_tensors(
o,
fn,
stack_variable_length_tensors(
sequences=tensors, constant_value=mask_token_id
sequences=tensors, # type: ignore
constant_value=mask_token_id,
),
)

View File

@@ -1,17 +1,15 @@
[project]
name = "esm"
version = "3.1.6"
version = "3.1.7"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE.txt"}
license = { file = "LICENSE.txt" }
authors = [
{name = "EvolutionaryScale Team"}
]
authors = [{ name = "EvolutionaryScale Team" }]
maintainers = [
{name = "Zeming Lin", email = "zeming+esm@evolutionaryscale.ai" }
{ name = "Zeming Lin", email = "zeming+esm@evolutionaryscale.ai" },
]
classifiers = [
@@ -36,11 +34,11 @@ dependencies = [
"pandas",
"cloudpathlib",
"tenacity",
"zstd"
"zstd",
]
[tool.setuptools]
package-dir = {"" = "."}
package-dir = { "" = "." }
include-package-data = true
[tool.setuptools.packages.find]