{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Guided Generation with ESM3\n",
"\n",
"Guided generation lets you steer ESM3 toward outputs that score highly under any scoring function you define.\n",
"\n",
"For example, you may want to:\n",
"1. Guide generations towards higher quality metrics like pTM\n",
"2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes\n",
"3. Minimize a biophysical energy function\n",
"4. Use experimental screening data to guide designs with a regression model\n",
"\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding, described in [Uehara et al., 2024](https://arxiv.org/abs/2408.08252).\n",
"\n",
"In this notebook we walk through a few examples to illustrate how to use guided generation:\n",
"\n",
"1. Guide towards high pTM for improved generation quality\n",
"\n",
"2. Generate a protein with no cysteine residues\n",
"\n",
"3. Maximize protein globularity by minimizing the radius of gyration\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title **IMPORTANT:** run this cell first before doing 'Runtime → Run all'\n",
"#@markdown - The latest update to Google Colab broke numpy; this is a temporary patch.\n",
"#@markdown - Note: after running this cell, the session will crash (this is normal).\n",
"\n",
"import os, numpy, signal\n",
"\n",
"if numpy.__version__ != '1.26.4':\n",
"    print(f\"Current numpy version {numpy.__version__} is incorrect. Installing 1.26.4...\")\n",
"    os.system(\"pip uninstall -y numpy\")\n",
"    os.system(\"pip install numpy==1.26.4\")\n",
"    # Kill the current process so the runtime restarts with the pinned numpy version\n",
"    os.kill(os.getpid(), signal.SIGKILL)\n",
"else:\n",
"    print(\"Numpy version is correct (1.26.4)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/evolutionaryscale/esm.git\n",
"!pip install py3dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import (\n",
"    ESM3GuidedDecoding,\n",
"    GuidedDecodingScoringFunction,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a scoring function\n",
"\n",
"To get started with the guided generation API, the only thing you need is a callable class that inherits from `GuidedDecodingScoringFunction`. Its `__call__` method receives an `ESMProtein` object and returns a numerical score. Guided decoding favors proteins with higher scores.\n",
"\n",
"For example, one computational metric we can use to measure the quality of a generated protein structure is the predicted Template Modeling score (pTM), so we'll use it to create a `PTMScoringFunction`.\n",
"\n",
"Fortunately for us, every time we generate a protein using ESM3 (either locally or on Forge) we also get its pTM. All our class needs to do when it's called is return the `ptm` attribute of its input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a scoring function (here, a pTM scoring function)\n",
"class PTMScoringFunction(GuidedDecodingScoringFunction):\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        # Minimal example of a scoring function that scores proteins based on their pTM score\n",
"        # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score\n",
"        assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
"        return float(protein.ptm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize your client\n",
"\n",
"Guided generation is compatible with both local inference using the `ESM3` class and remote inference with the Forge client."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# To use the tokenizers and the open model, you'll need to log in to Hugging Face\n",
"\n",
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Locally with ESM3-open\n",
"model = ESM3.from_pretrained().to(\"cuda\")\n",
"\n",
"## On Forge with larger ESM3 models\n",
"# from getpass import getpass\n",
"\n",
"# from esm.sdk import client\n",
"\n",
"# token = getpass(\"Token from Forge console: \")\n",
"# model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pTM Guided Generation\n",
"\n",
"Once your scoring function is defined and your model is initialized, you can create an `ESM3GuidedDecoding` instance and use it to sample guided designs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ptm_guided_decoding = ESM3GuidedDecoding(\n",
"    client=model, scoring_function=PTMScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"generated_protein = ptm_guided_decoding.guided_generate(\n",
"    protein=starting_protein,\n",
"    num_decoding_steps=len(starting_protein) // 8,\n",
"    num_samples_per_step=10,\n",
")"
]
},
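{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for `num_decoding_steps` and `num_samples_per_step`, here is a toy, self-contained sketch of greedy derivative-free guidance. It is *not* the `ESM3GuidedDecoding` implementation. The helper names (`toy_score`, `toy_propose`, `toy_guided_generate`) are hypothetical. A random proposal stands in for ESM3, each step unmasks an equal share of positions, and the highest-scoring of `num_samples_per_step` candidates is kept. Soft Value-Based Decoding as described in the paper may instead resample candidates with weights derived from their estimated values. Treat this only as an illustration of how more samples per step trade extra scoring calls for better scores."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'  # the 20 standard amino acids\n",
"MASK = '_'  # same mask character used for starting_protein above\n",
"\n",
"\n",
"def toy_score(sequence: str) -> float:\n",
"    # Hypothetical objective mirroring the no-cysteine idea: fewer C is better\n",
"    return float(-sequence.count('C'))\n",
"\n",
"\n",
"def toy_propose(sequence: str, n_unmask: int, rng: random.Random) -> str:\n",
"    # Hypothetical stand-in for one ESM3 decoding step: fill n_unmask randomly\n",
"    # chosen masked positions with uniformly random residues\n",
"    masked = [i for i, ch in enumerate(sequence) if ch == MASK]\n",
"    residues = list(sequence)\n",
"    for i in rng.sample(masked, min(n_unmask, len(masked))):\n",
"        residues[i] = rng.choice(AMINO_ACIDS)\n",
"    return ''.join(residues)\n",
"\n",
"\n",
"def toy_guided_generate(length: int, num_decoding_steps: int, num_samples_per_step: int, seed: int = 0) -> str:\n",
"    rng = random.Random(seed)\n",
"    sequence = MASK * length\n",
"    # Ceiling division so that every position is unmasked by the last step\n",
"    per_step = -(-length // num_decoding_steps)\n",
"    for _ in range(num_decoding_steps):\n",
"        # Draw several candidate continuations of the current partial sequence...\n",
"        candidates = [toy_propose(sequence, per_step, rng) for _ in range(num_samples_per_step)]\n",
"        # ...and greedily keep the one the scoring function prefers\n",
"        sequence = max(candidates, key=toy_score)\n",
"    return sequence\n",
"\n",
"\n",
"for samples in (1, 10):\n",
"    design = toy_guided_generate(length=64, num_decoding_steps=8, num_samples_per_step=samples)\n",
"    print('num_samples_per_step =', samples, '-> cysteines:', design.count('C'))"
]
},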
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare against baseline with no guidance\n",
"\n",
"First, we sample a protein without any guidance and then fold it with ESM3. Without pTM guidance, nothing steers sampling toward high-confidence structures, so the result may lack a clear fold."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate a protein WITHOUT guidance\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
"    input=starting_protein,\n",
"    config=GenerationConfig(track=\"sequence\", num_steps=len(starting_protein) // 8),\n",
")  # type: ignore\n",
"\n",
"# Fold\n",
"generated_protein_no_guided: ESMProtein = model.generate(\n",
"    input=generated_protein_no_guided,\n",
"    config=GenerationConfig(track=\"structure\", num_steps=1),\n",
")  # type: ignore"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a 1x2 grid of viewers (1 row, 2 columns)\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
"\n",
"# Convert ESMProtein objects to ProteinChain objects\n",
"protein_chain1 = generated_protein_no_guided.to_protein_chain()\n",
"protein_chain2 = generated_protein.to_protein_chain()\n",
"\n",
"# Add models to respective panels\n",
"view.addModel(protein_chain1.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
"view.addModel(protein_chain2.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
"\n",
"# Set styles for each protein\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 0))\n",
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 1))\n",
"\n",
"# Zoom and center the view\n",
"view.zoomTo()\n",
"view.show()"
]
},
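{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visual inspection can be complemented with a number. As a minimal sketch, the cell below prints the pTM of the unguided and the pTM-guided designs. It assumes both proteins carry a pTM value after folding, and the guard falls back to `nan` if one does not. A single pair of samples is anecdotal; comparing pTM distributions over several generations would be more informative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the pTM of the unguided and the pTM-guided designs\n",
"for name, protein in [\n",
"    ('no guidance', generated_protein_no_guided),\n",
"    ('pTM guidance', generated_protein),\n",
"]:\n",
"    # ptm may be missing if no structure was generated for this protein\n",
"    ptm = float(protein.ptm) if protein.ptm is not None else float('nan')\n",
"    print(f'{name}: pTM = {ptm:.3f}')"
]
},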
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate a Protein with No Cysteines\n",
"\n",
"Guided generation is not limited to structural metrics; you can also use it to guide sequence generation.\n",
"\n",
"For example, we can create a `NoCysteineScoringFunction` that penalizes a protein for each cysteine residue it contains."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NoCysteineScoringFunction(GuidedDecodingScoringFunction):\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        # Penalize proteins that contain cysteine\n",
"        assert protein.sequence is not None, \"Protein must have a sequence to be scored\"\n",
"        # Note that we use a negative score here to discourage the presence of cysteine:\n",
"        # each cysteine lowers the score by 1, so a cysteine-free protein scores 0\n",
"        return float(-protein.sequence.count(\"C\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_guided_decoding = ESM3GuidedDecoding(\n",
"    client=model, scoring_function=NoCysteineScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
"    protein=starting_protein,\n",
"    num_decoding_steps=len(starting_protein) // 8,\n",
"    num_samples_per_step=10,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check our sequence!\n",
"\n",
"If guided generation converged to `score == 0.00`, the resulting protein should contain no cysteine residues."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert no_cysteine_protein.sequence is not None, \"Protein must have a sequence\"\n",
"print(no_cysteine_protein.sequence)\n",
"print(f\"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}\")"
]
},
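{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same pattern extends to the amino acid frequency constraints mentioned in the introduction. The class below is a hypothetical sketch, not part of the ESM SDK; `TargetCompositionScoringFunction` and its `target_frequencies` argument are illustrative names. It scores a protein by the negative L1 distance between its observed residue frequencies and a target composition, so a perfect match scores 0 and larger deviations score lower.\n",
"\n",
"Two further assumptions apply. First, `GuidedDecodingScoringFunction` is assumed to need no constructor arguments. Second, masked positions (`_`) are skipped in case a partially decoded sequence is ever scored; whether that happens depends on the guided decoding implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from esm.sdk.api import ESMProtein\n",
"from esm.sdk.experimental import GuidedDecodingScoringFunction\n",
"\n",
"\n",
"class TargetCompositionScoringFunction(GuidedDecodingScoringFunction):\n",
"    # Hypothetical example: reward sequences whose amino acid frequencies\n",
"    # are close to a user-specified target composition\n",
"    def __init__(self, target_frequencies: dict[str, float]):\n",
"        super().__init__()\n",
"        self.target_frequencies = target_frequencies\n",
"\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        assert protein.sequence is not None, 'Protein must have a sequence to be scored'\n",
"        # Skip masked positions in case a partially decoded sequence is scored\n",
"        residues = [aa for aa in protein.sequence if aa != '_']\n",
"        if not residues:\n",
"            return 0.0  # nothing decoded yet, so no evidence either way\n",
"        # Negative L1 distance over the constrained residues: 0 is a perfect match\n",
"        distance = sum(\n",
"            abs(residues.count(aa) / len(residues) - target)\n",
"            for aa, target in self.target_frequencies.items()\n",
"        )\n",
"        return -distance\n",
"\n",
"\n",
"# Illustrative target: about 2% cysteine and about 10% alanine\n",
"composition_scoring = TargetCompositionScoringFunction({'C': 0.02, 'A': 0.10})\n",
"print(composition_scoring(ESMProtein(sequence='ACDEFGHIKL')))\n",
"\n",
"# Use it like the other scoring functions, e.g.\n",
"# ESM3GuidedDecoding(client=model, scoring_function=composition_scoring)"
]
},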
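{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maximize Globularity\n",
"\n",
"We use the radius of gyration as a proxy for globularity: a compact, globular protein has a smaller radius of gyration than an extended protein of the same length. The scoring function below rewards a small radius of gyration and additionally penalizes generations with low pTM."
]
},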
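{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the radius of gyration is the root-mean-square distance of the atoms from their centroid, $R_g = \\sqrt{\\frac{1}{N}\\sum_{i=1}^{N} \\lVert \\mathbf{r}_i - \\bar{\\mathbf{r}} \\rVert^2}$.\n",
"\n",
"The NumPy sketch below implements this unweighted definition on made-up coordinates (hypothetical toy data, not ESM3 output). It shows that the same number of points arranged compactly has a smaller $R_g$ than when arranged in a line. To our understanding, biotite's `gyration_radius`, used in the scoring function below, computes the same quantity with equal weights unless per-atom masses are supplied."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def toy_radius_of_gyration(coords: np.ndarray) -> float:\n",
"    # coords: (N, 3) array of positions; all points weighted equally\n",
"    centroid = coords.mean(axis=0)\n",
"    squared_distances = ((coords - centroid) ** 2).sum(axis=1)\n",
"    return float(np.sqrt(squared_distances.mean()))\n",
"\n",
"\n",
"# Hypothetical toy data: 27 points, once as a straight line and once as a 3x3x3 cube,\n",
"# with 3.8 Å spacing (roughly the distance between consecutive C-alpha atoms)\n",
"spacing = 3.8\n",
"extended = np.array([[spacing * i, 0.0, 0.0] for i in range(27)])\n",
"compact = spacing * np.array([[x, y, z] for x in range(3) for y in range(3) for z in range(3)], dtype=float)\n",
"\n",
"print('extended R_g:', toy_radius_of_gyration(extended))\n",
"print('compact  R_g:', toy_radius_of_gyration(compact))"
]
},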
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class RadiusOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
"    def __call__(self, protein: ESMProtein) -> float:\n",
"        # Smaller radius of gyration -> more compact protein -> higher score\n",
"        score = -1 * self.radius_of_gyration(protein)\n",
"\n",
"        assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
"        if protein.ptm < 0.5:\n",
"            # Penalize proteins with low pTM scores (the score is negative, so doubling it lowers it)\n",
"            score = score * 2\n",
"\n",
"        return float(score)\n",
"\n",
"    @staticmethod\n",
"    def radius_of_gyration(protein: ESMProtein) -> float:\n",
"        protein_chain = protein.to_protein_chain()\n",
"        arr = protein_chain.atom_array_no_insertions\n",
"        return bs.gyration_radius(arr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"radius_guided_decoding = ESM3GuidedDecoding(\n",
"    client=model, scoring_function=RadiusOfGyrationScoringFunction()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
"    protein=starting_protein,\n",
"    num_decoding_steps=len(starting_protein) // 8,\n",
"    num_samples_per_step=10,\n",
")"
]
},
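{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before visualizing, a quick numeric check compares the radius of gyration and pTM of the unguided baseline generated earlier with the radius-of-gyration-guided design. This is a minimal sketch that calls biotite's `gyration_radius` directly. Both proteins have the same length, so their radii of gyration are directly comparable. As before, a single pair of samples is only anecdotal."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare compactness (radius of gyration) and pTM of the unguided and guided designs\n",
"for name, protein in [\n",
"    ('no guidance', generated_protein_no_guided),\n",
"    ('radius of gyration guidance', radius_guided_protein),\n",
"]:\n",
"    rg = bs.gyration_radius(protein.to_protein_chain().atom_array_no_insertions)\n",
"    ptm = float(protein.ptm) if protein.ptm is not None else float('nan')\n",
"    print(f'{name}: radius of gyration = {float(rg):.1f} Å, pTM = {ptm:.3f}')"
]
},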
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"view = py3Dmol.view(width=800, height=400)\n",
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"view.zoomTo()\n",
"view.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}