{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Tutorial](https://github.com/biohub/esm/tree/main/cookbook/tutorials): Guided Generation with ESM3\n",
    "\n",
    "Guided generation lets you sample outputs from ESM3 that maximize any scoring function you define.\n",
    "\n",
    "For example, you may want to\n",
    "1. Guide generation towards higher values of quality metrics such as pTM\n",
    "2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes\n",
    "3. Minimize a biophysical energy function\n",
    "4. Use experimental screening data to guide designs with a regression model\n",
    "\n",
    "As long as your scoring function takes a protein as input and returns a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding, described in [Li et al. 2024](https://arxiv.org/abs/2408.08252), together with constrained optimization using the Modified Differential Method of Multipliers (MDMM) from [Platt & Barr 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf). Because the guidance is derivative-free, the scoring function does not need to be differentiable.\n",
    "\n",
    "In this notebook we walk through three examples of guided generation:\n",
    "\n",
    "1. Guide towards high pTM for improved generation quality\n",
    "2. Generate a protein with no cysteine (C) residues\n",
    "3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # If you are working in Colab, uncomment these lines to install dependencies\n",
    "# !pip install esm@git+https://github.com/Biohub/esm.git@c94ed8d\n",
    "# !pip install py3dmol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import biotite.structure as bs\n",
    "import py3Dmol\n",
    "\n",
    "from esm.sdk.api import ESMProtein, GenerationConfig\n",
    "from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a scoring function\n",
    "\n",
    "To get started with the guided generation API, all you need is a callable class that inherits from `GuidedDecodingScoringFunction`. Its `__call__` method receives an `ESMProtein` object and returns a single numerical score. Guided generation favors proteins with higher scores.\n",
    "\n",
    "For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the predicted Template Modeling score (pTM), so we'll use it to create a `PTMScoringFunction`.\n",
    "\n",
    "Fortunately for us, ESM3 (run locally or on Biohub) reports a pTM estimate alongside the structures it predicts, so all our class needs to do when it's called is return the `ptm` attribute of its input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create scoring function (e.g. pTM scoring function)\n",
    "class PTMScoringFunction(GuidedDecodingScoringFunction):\n",
    "    def __call__(self, protein: ESMProtein) -> float:\n",
    "        # Minimal example of a scoring function that scores proteins based on their pTM score\n",
    "        # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score\n",
    "        assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
    "        return float(protein.ptm)"
   ]
  },
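  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same pattern covers the second use case from the introduction, steering amino acid frequencies. The class below is a hypothetical sketch, not part of the SDK. The set of hydrophobic residues and the 0.4 target fraction are illustrative assumptions. Mask tokens (`_`) are skipped in case the function is ever called on a partially decoded protein. Because a scoring function is just a callable, it can be checked on a hand-written `ESMProtein` without calling the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from esm.sdk.api import ESMProtein\n",
    "from esm.sdk.experimental import GuidedDecodingScoringFunction\n",
    "\n",
    "\n",
    "class HydrophobicFractionScoringFunction(GuidedDecodingScoringFunction):\n",
    "    # Hypothetical example: reward sequences whose hydrophobic fraction is close to a target\n",
    "    HYDROPHOBIC = set('AILMFVW')  # illustrative choice of hydrophobic residues\n",
    "\n",
    "    def __init__(self, target_fraction: float = 0.4):\n",
    "        super().__init__()\n",
    "        self.target_fraction = target_fraction\n",
    "\n",
    "    def __call__(self, protein: ESMProtein) -> float:\n",
    "        assert protein.sequence is not None, 'Protein must have a sequence to be scored'\n",
    "        residues = [aa for aa in protein.sequence if aa != '_']  # skip mask tokens\n",
    "        if not residues:\n",
    "            return -1.0  # nothing decoded yet: as low as any real score can be\n",
    "        fraction = sum(aa in self.HYDROPHOBIC for aa in residues) / len(residues)\n",
    "        # 0 at the target, decreasing linearly with the distance from it\n",
    "        return -abs(fraction - self.target_fraction)\n",
    "\n",
    "\n",
    "# Local check without the model: 4 of the 8 residues are hydrophobic (fraction 0.5)\n",
    "example_score = HydrophobicFractionScoringFunction(target_fraction=0.4)(ESMProtein(sequence='AILMKKKK'))\n",
    "print(f'{example_score:.2f}')  # -0.10"
   ]
  },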
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize your client\n",
    "\n",
    "Guided generation works both with local inference through the `ESM3` class and with remote inference through the Biohub client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To use the tokenizers and the open model, you'll need to log in to Hugging Face\n",
    "\n",
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Locally with ESM3-open\n",
    "# from esm.models.esm3 import ESM3\n",
    "# model = ESM3.from_pretrained().to(\"cuda\")\n",
    "\n",
    "## On Biohub with larger ESM3 models\n",
    "from getpass import getpass\n",
    "\n",
    "from esm.sdk import client\n",
    "\n",
    "token = getpass(\"Token from Biohub: \")\n",
    "model = client(model=\"esm3-medium-2024-08\", url=\"https://biohub.ai\", token=token)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Guide towards high pTM for improved generation quality\n",
    "\n",
    "Once your scoring function is defined and your model is initialized, create an `ESM3GuidedDecoding` instance and sample from it with `guided_generate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ptm_guided_decoding = ESM3GuidedDecoding(\n",
    "    client=model, scoring_function=PTMScoringFunction()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fully masked protein\n",
    "PROTEIN_LENGTH = 256\n",
    "starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
    "\n",
    "# Call guided_generate\n",
    "generated_protein = ptm_guided_decoding.guided_generate(\n",
    "    protein=starting_protein,\n",
    "    num_decoding_steps=len(starting_protein) // 8,  # 256 // 8 = 32 decoding steps\n",
    "    num_samples_per_step=10,  # candidates sampled and scored at each step\n",
    ")"
   ]
  },
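  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a self-contained toy sketch of the best-of-N idea behind Soft Value-Based Decoding. It is not the code inside `ESM3GuidedDecoding`, whose internals this notebook does not show, and all names in it are hypothetical. `toy_unmask` stands in for one ESM3 decoding step by filling random masked positions with random amino acids, and `toy_score` mirrors the no-cysteine score used later in this notebook. At each of `num_decoding_steps` steps, `num_samples_per_step` candidates are drawn and scored, and one is kept. With `temperature == 0` the best candidate is kept; otherwise one is sampled with probability proportional to `exp(score / temperature)`. As a simplification, the score of the partially decoded sequence is used directly as the value estimate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "# Hypothetical toy sketch of value-guided decoding; it does not call ESM3.\n",
    "TOY_ALPHABET = 'ACDEFGHIKLMNPQRSTVWY'  # the 20 standard amino acids\n",
    "TOY_MASK = '_'\n",
    "\n",
    "\n",
    "def toy_unmask(seq, n_positions, rng):\n",
    "    # Stand-in for one model step: fill up to n random masked positions at random\n",
    "    chars = list(seq)\n",
    "    masked = [i for i, c in enumerate(chars) if c == TOY_MASK]\n",
    "    for i in rng.sample(masked, min(n_positions, len(masked))):\n",
    "        chars[i] = rng.choice(TOY_ALPHABET)\n",
    "    return ''.join(chars)\n",
    "\n",
    "\n",
    "def toy_score(seq):\n",
    "    # Higher is better: each cysteine costs one point\n",
    "    return -seq.count('C')\n",
    "\n",
    "\n",
    "def toy_guided_decode(length, num_decoding_steps, num_samples_per_step, temperature=0.0, seed=0):\n",
    "    rng = random.Random(seed)\n",
    "    seq = TOY_MASK * length\n",
    "    per_step = -(-length // num_decoding_steps)  # ceiling division, so every position gets filled\n",
    "    for _ in range(num_decoding_steps):\n",
    "        candidates = [toy_unmask(seq, per_step, rng) for _ in range(num_samples_per_step)]\n",
    "        scores = [toy_score(c) for c in candidates]\n",
    "        if temperature == 0:\n",
    "            seq = candidates[scores.index(max(scores))]  # greedy: keep the best candidate\n",
    "        else:\n",
    "            best = max(scores)  # subtract the max before exponentiating, for numerical stability\n",
    "            weights = [math.exp((s - best) / temperature) for s in scores]\n",
    "            seq = rng.choices(candidates, weights=weights, k=1)[0]\n",
    "    return seq\n",
    "\n",
    "\n",
    "toy_seq = toy_guided_decode(length=64, num_decoding_steps=8, num_samples_per_step=10)\n",
    "print(toy_seq)\n",
    "print('Cysteines:', toy_seq.count('C'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Drawing 64 residues uniformly from 20 amino acids gives 64 / 20 = 3.2 cysteines on average, so a count at or near zero shows the per-step selection at work. Raising `temperature` makes the selection softer. In the limit of a very high temperature it becomes equivalent to unguided sampling."
   ]
  },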
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare against baseline with no guidance\n",
    "\n",
    "First, we sample a protein without any guidance, using the same number of decoding steps. Without pTM guidance, nothing steers generation toward proteins with a well-defined structure, so the baseline may have no clear fold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a protein WITHOUT guidance, using the same number of decoding steps\n",
    "generated_protein_no_guided: ESMProtein = model.generate(\n",
    "    input=starting_protein,\n",
    "    config=GenerationConfig(track=\"sequence\", num_steps=len(starting_protein) // 8),\n",
    ")  # type: ignore\n",
    "\n",
    "# Fold: predict the structure of the generated sequence (this also yields its pTM)\n",
    "generated_protein_no_guided: ESMProtein = model.generate(\n",
    "    input=generated_protein_no_guided,\n",
    "    config=GenerationConfig(track=\"structure\", num_steps=1),\n",
    ")  # type: ignore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a 1x2 grid of viewers (1 row, 2 columns)\n",
    "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
    "\n",
    "# Convert ESMProtein objects to ProteinChain objects\n",
    "protein_chain1 = generated_protein_no_guided.to_protein_chain()\n",
    "protein_chain2 = generated_protein.to_protein_chain()\n",
    "\n",
    "# Add models to their panels: unguided on the left, pTM-guided on the right\n",
    "view.addModel(protein_chain1.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
    "view.addModel(protein_chain2.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
    "\n",
    "# Set styles for each protein\n",
    "view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 0))\n",
    "view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 1))\n",
    "\n",
    "# Zoom and center the view\n",
    "view.zoomTo()\n",
    "view.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"pTM without guidance: {generated_protein_no_guided.ptm:.3f}\")\n",
    "print(f\"pTM with guidance: {generated_protein.ptm:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate a protein with no cysteines\n",
    "\n",
    "Guided generation is not limited to structural metrics; you can also use it to steer the sequence itself.\n",
    "\n",
    "For example, we can create a `NoCysteineScoringFunction` that lowers the score by one for every cysteine residue in the protein."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoCysteineScoringFunction(GuidedDecodingScoringFunction):\n",
    "    def __call__(self, protein: ESMProtein) -> float:\n",
    "        # Penalize proteins that contain cysteine\n",
    "        assert protein.sequence is not None, \"Protein must have a sequence to be scored\"\n",
    "        # Note that we use a negative score here, to discourage the presence of cysteine:\n",
    "        # each cysteine lowers the score by 1, so the best possible score is 0\n",
    "        return float(-protein.sequence.count(\"C\"))"
   ]
  },
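  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since a scoring function is an ordinary callable, its sign convention can be checked on hand-written sequences before launching a guided run. This quick check uses the `NoCysteineScoringFunction` defined in the previous cell and two made-up sequences. It does not call the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from esm.sdk.api import ESMProtein\n",
    "\n",
    "no_cys_score = NoCysteineScoringFunction()\n",
    "for seq in ['MKTAYIAK', 'MKCAYCAK']:\n",
    "    # 0 is the best possible score; every cysteine lowers it by 1\n",
    "    print(seq, f'{no_cys_score(ESMProtein(sequence=seq)):.1f}')\n",
    "# Expected output: 'MKTAYIAK 0.0', then 'MKCAYCAK -2.0'"
   ]
  },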
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_cysteine_guided_decoding = ESM3GuidedDecoding(\n",
    "    client=model, scoring_function=NoCysteineScoringFunction()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fully masked protein\n",
    "PROTEIN_LENGTH = 256\n",
    "starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
    "\n",
    "# Call guided_generate\n",
    "no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
    "    protein=starting_protein,\n",
    "    num_decoding_steps=len(starting_protein) // 8,\n",
    "    num_samples_per_step=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check our sequence!\n",
    "\n",
    "If guided generation converged to `score == 0.00`, the resulting protein should contain no cysteine residues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert no_cysteine_protein.sequence is not None, \"Protein must have a sequence\"\n",
    "print(no_cysteine_protein.sequence)\n",
    "print(f\"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Maximize globularity\n",
    "\n",
    "We use the radius of gyration as a proxy for globularity. At a fixed length, a more compact protein has a smaller radius of gyration, so we maximize its negative. We also keep pTM high, this time by imposing it as a constraint instead of adding it to the score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from esm.sdk.experimental import (\n",
    "    ConstraintType,\n",
    "    ESM3GuidedDecodingWithConstraints,\n",
    "    GenerationConstraint,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
    "    def __call__(self, protein: ESMProtein) -> float:\n",
    "        # Use the negative radius of gyration as the score to maximize\n",
    "        score = -1 * self.radius_of_gyration(protein)\n",
    "\n",
    "        # Rescale so the score has a magnitude similar to pTM, which lies in [0, 1]\n",
    "        score = score / 100.0\n",
    "\n",
    "        return score\n",
    "\n",
    "    @staticmethod\n",
    "    def radius_of_gyration(protein: ESMProtein) -> float:\n",
    "        # Radius of gyration (in Angstrom) over all atoms of the predicted structure\n",
    "        protein_chain = protein.to_protein_chain()\n",
    "        arr = protein_chain.atom_array_no_insertions\n",
    "        return float(bs.gyration_radius(arr))"
   ]
  },
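  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the scale of this score, the cell below sketches the radius of gyration in NumPy: the root-mean-square distance of the atoms from their (optionally mass-weighted) center. `radius_of_gyration_sketch` is a hypothetical helper, not a biotite function. To our understanding, `bs.gyration_radius` weights atoms by their standard atomic masses by default, which is why the sketch accepts optional `masses`. With `masses=None` every atom counts equally, so results on a real protein can differ slightly from biotite's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def radius_of_gyration_sketch(coords, masses=None):\n",
    "    # coords: (N, 3) atom positions in Angstrom; masses: optional per-atom weights\n",
    "    coords = np.asarray(coords, dtype=float)\n",
    "    weights = np.ones(len(coords)) if masses is None else np.asarray(masses, dtype=float)\n",
    "    center = (weights[:, None] * coords).sum(axis=0) / weights.sum()\n",
    "    sq_dist = ((coords - center) ** 2).sum(axis=1)\n",
    "    return float(np.sqrt((weights * sq_dist).sum() / weights.sum()))\n",
    "\n",
    "\n",
    "# Four atoms, each 1 Angstrom from their common center\n",
    "toy_coords = [[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0]]\n",
    "toy_rg = radius_of_gyration_sketch(toy_coords)\n",
    "print(toy_rg)  # 1.0\n",
    "print(-toy_rg / 100.0)  # -0.01, the value RadiousOfGyrationScoringFunction's formula gives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A 256-residue protein typically has a radius of gyration of some tens of Angstrom. Dividing by 100 therefore brings the score to roughly the same order of magnitude as pTM."
   ]
  },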
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constrain generation to have pTM >= 0.75\n",
    "ptm_constraint = GenerationConstraint(\n",
    "    scoring_function=PTMScoringFunction(),\n",
    "    constraint_type=ConstraintType.GREATER_EQUAL,\n",
    "    value=0.75,\n",
    ")\n",
    "\n",
    "radius_guided_decoding = ESM3GuidedDecodingWithConstraints(\n",
    "    client=model,\n",
    "    scoring_function=RadiousOfGyrationScoringFunction(),\n",
    "    constraints=[ptm_constraint],  # List of constraints to satisfy\n",
    "    damping=1.0,  # Damping factor for the MDMM algorithm\n",
    "    learning_rate=10.0,  # Learning rate for the MDMM algorithm\n",
    ")"
   ]
  },
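  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for `damping` and `learning_rate`, the cell below runs the textbook gradient-based form of the Modified Differential Method of Multipliers (Platt & Barr 1987) on a toy equality-constrained problem. It is an illustrative sketch only. `mdmm_toy` is a hypothetical helper, and `ESM3GuidedDecodingWithConstraints` is derivative-free and supports inequality constraints such as `GREATER_EQUAL`, so its update rules may differ. The variables take gradient-descent steps on the damped Lagrangian `f + lam * g + (damping / 2) * g**2`. The multiplier `lam` takes gradient-ascent steps on the constraint violation `g`. Both step sizes are scaled by `learning_rate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mdmm_toy(learning_rate=0.05, damping=1.0, steps=2000):\n",
    "    # Minimize f(x, y) = x**2 + y**2 subject to g(x, y) = x + y - 1 = 0.\n",
    "    # The exact solution is x = y = 0.5, with Lagrange multiplier lam = -1.\n",
    "    x, y, lam = 0.0, 0.0, 0.0\n",
    "    for _ in range(steps):\n",
    "        g = x + y - 1.0  # constraint violation\n",
    "        # Partial derivatives of f + lam * g + (damping / 2) * g**2 (dg/dx = dg/dy = 1)\n",
    "        grad_x = 2 * x + lam + damping * g\n",
    "        grad_y = 2 * y + lam + damping * g\n",
    "        # Descend on the variables, ascend on the multiplier\n",
    "        x -= learning_rate * grad_x\n",
    "        y -= learning_rate * grad_y\n",
    "        lam += learning_rate * g\n",
    "    return x, y, lam\n",
    "\n",
    "\n",
    "for toy_damping in (0.0, 1.0):\n",
    "    toy_x, toy_y, toy_lam = mdmm_toy(damping=toy_damping)\n",
    "    print(f'damping={toy_damping}: x={toy_x:.3f}, y={toy_y:.3f}, lam={toy_lam:.3f}')\n",
    "# Both lines should report x=0.500, y=0.500, lam=-1.000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both settings reach the solution in this toy. With `damping=0.0` the iterates spiral in toward it, because the linearized dynamics have complex eigenvalues. With `damping=1.0` they approach it without oscillating. As we read Platt & Barr, damping such oscillations of the plain multiplier method is the purpose of the extra quadratic term."
   ]
  },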
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fully masked protein\n",
    "PROTEIN_LENGTH = 256\n",
    "starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
    "\n",
    "# Call guided_generate\n",
    "radius_guided_protein = radius_guided_decoding.guided_generate(\n",
    "    protein=starting_protein,\n",
    "    num_decoding_steps=len(starting_protein) // 8,\n",
    "    num_samples_per_step=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the trajectory of the constrained generation\n",
    "radius_guided_decoding.visualize_latest_trajectory()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the generated protein\n",
    "view = py3Dmol.view(width=800, height=400)\n",
    "view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
    "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
    "view.zoomTo()\n",
    "view.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check pTM\n",
    "radius_guided_protein.ptm"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}