esm/cookbook/tutorials/5_guided_generation.ipynb
2025-04-01 17:21:33 -07:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Guided Generation with ESM3\n",
"\n",
"Guided generation is a powerful tool that allows you to sample outputs out of ESM3 that maximize any kind of score function.\n",
"\n",
"For example, you may want to\n",
"1. Guide generations towards higher quality metrics like pTM\n",
"2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes\n",
"3. Minimize a biophysical energy function\n",
"4. Use experimental screening data to guide designs with a regression model\n",
"\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Uehara, et al 2024](https://arxiv.org/abs/2408.08252).\n",
"\n",
"In this notebook we will walk through a few examples to illustrate how to use guided generation. \n",
"\n",
"1. Guide towards high pTM for improved generation quality\n",
"\n",
"2. Generate a protein with no Cysteine residues\n",
"\n",
"3. Maximize protein globularity by minimizing the radius of gyration\n",
"\n"
]
},
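{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for derivative-free guidance, here is a toy sketch that runs without ESM3. It is not the `ESM3GuidedDecoding` implementation: the names `toy_score` and `toy_guided_decode` are hypothetical, candidates are drawn uniformly at random instead of from a model, and the best candidate is kept greedily at every step, which is a simplification of the soft sampling described in the paper.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'  # the 20 standard amino acids\n",
"\n",
"\n",
"def toy_score(sequence: str) -> float:\n",
"    # Hypothetical score: count alanines (masked positions contribute nothing)\n",
"    return float(sequence.count('A'))\n",
"\n",
"\n",
"def toy_guided_decode(length: int, num_steps: int, num_samples_per_step: int, seed: int = 0) -> str:\n",
"    rng = random.Random(seed)\n",
"    sequence = ['_'] * length\n",
"    per_step = -(-length // num_steps)  # ceiling division\n",
"    for _ in range(num_steps):\n",
"        positions = [i for i, aa in enumerate(sequence) if aa == '_'][:per_step]\n",
"        # Propose several random ways to unmask this step's positions\n",
"        candidates = []\n",
"        for _ in range(num_samples_per_step):\n",
"            candidate = list(sequence)\n",
"            for i in positions:\n",
"                candidate[i] = rng.choice(ALPHABET)\n",
"            candidates.append(candidate)\n",
"        # Keep the highest-scoring candidate; only score evaluations are needed, no gradients\n",
"        sequence = max(candidates, key=lambda c: toy_score(''.join(c)))\n",
"    return ''.join(sequence)\n",
"\n",
"\n",
"guided = toy_guided_decode(length=32, num_steps=4, num_samples_per_step=10)\n",
"print(guided, toy_score(guided))\n",
"```\n",
"\n",
"In this toy, raising `num_samples_per_step` explores more candidates at each step, at a proportional cost in score evaluations."
]
},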
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title **IMPORTANT:** run this cell first before doing 'Runtime →> Run all'\n",
"#@markdown - The latest update to Google Colab broke numpy; this is a temporary patch.\n",
"#@markdown - Note after running this cell, the session will crash (this is normal).\n",
"\n",
"import os, numpy, signal\n",
"\n",
"if numpy.__version__ != '1.26.4':\n",
" print(f\"Current numpy version {numpy.__version__} is incorrect. Installing 1.26.4...\")\n",
" os.system(\"'pip uninstall -y numpy\")\n",
" os.system(\"pip install numpy==1.26.4\")\n",
" # Restart the runtime using os.kill\n",
" os.kill(os. getpid(), signal.SIGKILL)\n",
"else:\n",
" print (\"Numpy version is correct (1.26.4)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/evolutionaryscale/esm.git\n",
"!pip install py3dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import (\n",
" ESM3GuidedDecoding,\n",
" GuidedDecodingScoringFunction,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a scoring function\n",
"\n",
"To get started with the guided generation API the only thing you need is to create a callable class that inherits from `GuidedDecodingScoringFunction`. This class should receive as input an `ESMProtein` object and output a numerical score.\n",
"\n",
"\n",
"For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the Predicted Template Modelling (pTM) score, so we'll use it to create a `PTMScoringFunction`.\n",
"\n",
"Fortunately for us, every time we generate a protein using ESM3 (either locally or on Forge) we also get its pTM, so all our class needs to do when its called is to return the `ptm` attribute of its input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create scoring function (e.g. PTM scoring function)\n",
"class PTMScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" # Minimal example of a scoring function that scores proteins based on their pTM score\n",
" # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score\n",
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
" return float(protein.ptm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize your client\n",
"\n",
"The guided generation is compatible with both local inference using the `ESM3` class and remote inference with the Forge client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# To use the tokenizers and the open model you'll need to login into Hugging Face\n",
"\n",
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Locally with ESM3-open\n",
"model = ESM3.from_pretrained().to(\"cuda\")\n",
"\n",
"## On Forge with larger ESM3 models\n",
"# from getpass import getpass\n",
"\n",
"# from esm.sdk import client\n",
"\n",
"# token = getpass(\"Token from Forge console: \")\n",
"# model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pTM Guided Generation\n",
"\n",
"Once your scoring function is defined and you have initialized your model you can create an `ESM3GuidedDecoding` instance to sample from it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ptm_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=PTMScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"generated_protein = ptm_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
" num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare against baseline with no guidance\n",
"\n",
"First we are going to sample a protein generated without any guidance. This means that, when not providing pTM guidance, we could be sampling proteins that have no clear structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate a protein WITHOUT guidance\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
" input=starting_protein,\n",
" config=GenerationConfig(track=\"sequence\", num_steps=len(starting_protein) // 8),\n",
") # type: ignore\n",
"\n",
"# Fold\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
" input=generated_protein_no_guided,\n",
" config=GenerationConfig(track=\"structure\", num_steps=1),\n",
") # type: ignore"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a 1x2 grid of viewers (1 row, 2 columns)\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
"\n",
"# Convert ESMProtein objects to ProteinChain objects\n",
"protein_chain1 = generated_protein_no_guided.to_protein_chain()\n",
"protein_chain2 = generated_protein.to_protein_chain()\n",
"\n",
"# Add models to respective panels\n",
"view.addModel(protein_chain1.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
"view.addModel(protein_chain2.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
"\n",
"# Set styles for each protein\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 0))\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 1))\n",
"\n",
"# Zoom and center the view\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate a Protein with No Cysteines\n",
"\n",
"Guided generation is not constrained to structural metrics, you can also use it to guide the sequence generation.\n",
"\n",
"For example, we can create a `NoCysteineScoringFunction` that penalizes the protein if it contains Cysteine residues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NoCysteineScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" # Penalize proteins that contain cysteine\n",
" assert protein.sequence is not None, \"Protein must have a sequence to be scored\"\n",
" # Note that we use a negative score here, to discourage the presence of cysteine\n",
" return -protein.sequence.count(\"C\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=NoCysteineScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
" num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check our sequence!\n",
"\n",
"If guided generation converged to `score == 0.00`, the resulting protein should contain no Cysteine residues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert no_cysteine_protein.sequence is not None, \"Protein must have a sequence\"\n",
"print(no_cysteine_protein.sequence)\n",
"print(f\"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}\")"
]
},
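{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same pattern extends to other sequence-level constraints, such as the amino acid frequencies mentioned in the introduction. The sketch below is a plain-Python illustration that does not depend on ESM3; the function name `frequency_score` and the target frequencies are hypothetical. To use it for guidance, you would call it on `protein.sequence` inside the `__call__` of a `GuidedDecodingScoringFunction` subclass, as `NoCysteineScoringFunction` does above.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"\n",
"def frequency_score(sequence: str, target: dict[str, float]) -> float:\n",
"    # Ignore positions that are still masked ('_')\n",
"    residues = [aa for aa in sequence if aa != '_']\n",
"    counts = Counter(residues)\n",
"    total = max(len(residues), 1)  # avoid division by zero on empty input\n",
"    # Sum of absolute deviations between observed and target frequencies\n",
"    deviation = sum(abs(counts[aa] / total - freq) for aa, freq in target.items())\n",
"    # Negate so that a closer match gives a higher score (0.0 is a perfect match)\n",
"    return float(-deviation)\n",
"\n",
"\n",
"target = {'A': 0.5, 'K': 0.5}  # hypothetical target composition\n",
"print(frequency_score('AAKK', target))\n",
"print(frequency_score('AAAA', target))\n",
"```\n",
"\n",
"As with the cysteine example, the score is negated because guided generation maximizes the score."
]
},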
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maximize Globularity\n",
"\n",
"We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM"
]
},
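{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the quantity being minimized, the sketch below computes an unweighted radius of gyration, the root-mean-square distance of points from their centroid, in plain Python. It assumes equal atomic masses, whereas `bs.gyration_radius` weights atoms by mass by default, so the two will not agree exactly. The toy coordinates are made up for the example.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def radius_of_gyration(coords: list[tuple[float, float, float]]) -> float:\n",
"    n = len(coords)\n",
"    # Centroid of the points (equal masses assumed)\n",
"    center = [sum(p[k] for p in coords) / n for k in range(3)]\n",
"    # Mean squared distance of each point from the centroid\n",
"    mean_sq = sum(sum((p[k] - center[k]) ** 2 for k in range(3)) for p in coords) / n\n",
"    return math.sqrt(mean_sq)\n",
"\n",
"\n",
"compact = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)]   # points on a unit square\n",
"extended = [(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0)]  # points along a line\n",
"print(radius_of_gyration(compact), radius_of_gyration(extended))\n",
"```\n",
"\n",
"The square arrangement has the smaller radius of gyration, which is why negating this value rewards compact, globular shapes."
]
},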
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" score = -1 * self.radius_of_gyration(protein)\n",
"\n",
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
" if protein.ptm < 0.5:\n",
" # Penalize proteins with low pTM scores\n",
" score = score * 2\n",
"\n",
" return score\n",
"\n",
" @staticmethod\n",
" def radius_of_gyration(protein: ESMProtein) -> float:\n",
" protein_chain = protein.to_protein_chain()\n",
" arr = protein_chain.atom_array_no_insertions\n",
" return bs.gyration_radius(arr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"radius_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=RadiousOfGyrationScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
" num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"view = py3Dmol.view(width=800, height=400)\n",
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"view.zoomTo()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}