{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# [Tutorial](https://github.com/biohub/esm/tree/main/cookbook/tutorials): Guided Generation with ESM3\n",
"\n",
"Guided generation lets you steer ESM3 sampling toward outputs that score highly under any scoring function you define.\n",
"\n",
"For example, you may want to:\n",
"1. Guide generations towards higher quality metrics like pTM\n",
"2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes\n",
"3. Minimize a biophysical energy function\n",
"4. Use experimental screening data to guide designs with a regression model\n",
"\n",
"As long as your scoring function takes a protein as input and returns a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding, described in [Li et al. 2024](https://arxiv.org/abs/2408.08252), together with constrained optimization using the Modified Differential Method of Multipliers from [Platt & Barr 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf).\n",
"\n",
"In this notebook we walk through three examples that illustrate how to use guided generation:\n",
"\n",
"1. Guide towards high pTM for improved generation quality\n",
"2. Generate a protein with no cysteine (C) residues\n",
"3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # If you are working in colab, uncomment these lines to install dependencies\n",
"# !pip install esm@git+https://github.com/Biohub/esm.git@c94ed8d\n",
"# !pip install py3dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a scoring function\n",
"\n",
"To get started with the guided generation API, all you need is a callable class that inherits from `GuidedDecodingScoringFunction`. Its `__call__` method should take an `ESMProtein` object as input and return a numerical score.\n",
"\n",
"For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the Predicted Template Modelling (pTM) score, so we'll use it to create a `PTMScoringFunction`.\n",
"\n",
"Fortunately, every time we generate a protein using ESM3 (either locally or on Biohub) we also get its pTM, so all our class needs to do when it's called is return the `ptm` attribute of its input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create scoring function (e.g. PTM scoring function)\n",
"class PTMScoringFunction(GuidedDecodingScoringFunction):\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        # Minimal example of a scoring function that scores proteins based on their pTM score\n",
"        # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score\n",
"        assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
"        return float(protein.ptm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize your client\n",
"\n",
"Guided generation is compatible with both local inference using the `ESM3` class and remote inference with the Biohub client."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# To use the tokenizers and the open model you'll need to log in to Hugging Face\n",
"\n",
"from huggingface_hub import notebook_login\n",
"\n",
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Locally with ESM3-open\n",
"# from esm.models.esm3 import ESM3\n",
"# model = ESM3.from_pretrained().to(\"cuda\")\n",
"\n",
"## On Biohub with larger ESM3 models\n",
"from getpass import getpass\n",
"\n",
"from esm.sdk import client\n",
"\n",
"token = getpass(\"Token from Biohub: \")\n",
"model = client(model=\"esm3-medium-2024-08\", url=\"https://biohub.ai\", token=token)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Guide towards high pTM for improved generation quality\n",
"\n",
"Once your scoring function is defined and your model is initialized, you can create an `ESM3GuidedDecoding` instance to sample from it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ptm_guided_decoding = ESM3GuidedDecoding(\n",
"    client=model, scoring_function=PTMScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"generated_protein = ptm_guided_decoding.guided_generate(\n",
"    protein=starting_protein,\n",
"    num_decoding_steps=len(starting_protein) // 8,\n",
"    num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare against baseline with no guidance\n",
"\n",
"As a baseline, we first sample a protein without any guidance. Without pTM guidance, the model may produce proteins that have no clear structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate a protein WITHOUT guidance\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
"    input=starting_protein,\n",
"    config=GenerationConfig(track=\"sequence\", num_steps=len(starting_protein) // 8),\n",
")  # type: ignore\n",
"\n",
"# Fold\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
"    input=generated_protein_no_guided,\n",
"    config=GenerationConfig(track=\"structure\", num_steps=1),\n",
")  # type: ignore"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a 1x2 grid of viewers (1 row, 2 columns)\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
"\n",
"# Convert ESMProtein objects to ProteinChain objects\n",
"protein_chain1 = generated_protein_no_guided.to_protein_chain()\n",
"protein_chain2 = generated_protein.to_protein_chain()\n",
"\n",
"# Add models to respective panels\n",
"view.addModel(protein_chain1.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
"view.addModel(protein_chain2.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
"\n",
"# Set styles for each protein\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 0))\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 1))\n",
"\n",
"# Zoom and center the view\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"pTM Without guidance: {generated_protein_no_guided.ptm:.3f}\")\n",
"print(f\"pTM With guidance: {generated_protein.ptm:.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate a Protein with No Cysteines\n",
"\n",
"Guided generation is not limited to structural metrics; you can also use it to guide sequence generation.\n",
"\n",
"For example, we can create a `NoCysteineScoringFunction` that penalizes the protein if it contains cysteine residues."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NoCysteineScoringFunction(GuidedDecodingScoringFunction):\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        # Penalize proteins that contain cysteine\n",
"        assert protein.sequence is not None, \"Protein must have a sequence to be scored\"\n",
"        # Negate the cysteine count so that fewer cysteines means a higher score\n",
"        return float(-protein.sequence.count(\"C\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_guided_decoding = ESM3GuidedDecoding(\n",
"    client=model, scoring_function=NoCysteineScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
"    protein=starting_protein,\n",
"    num_decoding_steps=len(starting_protein) // 8,\n",
"    num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check our sequence!\n",
"\n",
"If guided generation converged to `score == 0.00`, the resulting protein should contain no cysteine residues."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert no_cysteine_protein.sequence is not None, \"Protein must have a sequence\"\n",
"print(no_cysteine_protein.sequence)\n",
"print(f\"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maximize Globularity\n",
"\n",
"We minimize the radius of gyration as a proxy for maximizing globularity, and we use a constraint to keep pTM high."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from esm.sdk.experimental import (\n",
"    ConstraintType,\n",
"    ESM3GuidedDecodingWithConstraints,\n",
"    GenerationConstraint,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RadiusOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        # Use the negative radius of gyration as the score to maximize\n",
"        score = -1 * self.radius_of_gyration(protein)\n",
"\n",
"        # Rescale the score to a magnitude similar to pTM\n",
"        score = score / 100.0\n",
"\n",
"        return score\n",
"\n",
"    @staticmethod\n",
"    def radius_of_gyration(protein: ESMProtein) -> float:\n",
"        protein_chain = protein.to_protein_chain()\n",
"        arr = protein_chain.atom_array_no_insertions\n",
"        return float(bs.gyration_radius(arr))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Constrain generation to have pTM >= 0.75\n",
"ptm_constraint = GenerationConstraint(\n",
"    scoring_function=PTMScoringFunction(),\n",
"    constraint_type=ConstraintType.GREATER_EQUAL,\n",
"    value=0.75,\n",
")\n",
"\n",
"radius_guided_decoding = ESM3GuidedDecodingWithConstraints(\n",
"    client=model,\n",
"    scoring_function=RadiusOfGyrationScoringFunction(),\n",
"    constraints=[ptm_constraint],  # Add list of constraints\n",
"    damping=1.0,  # Damping factor for the MDMM algorithm\n",
"    learning_rate=10.0,  # Learning rate for the MDMM algorithm\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
"    protein=starting_protein,\n",
"    num_decoding_steps=len(starting_protein) // 8,\n",
"    num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the trajectory of the constrained generation\n",
"radius_guided_decoding.visualize_latest_trajectory()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the generated protein\n",
"view = py3Dmol.view(width=800, height=400)\n",
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"view.zoomTo()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check pTM\n",
"radius_guided_protein.ptm"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}