mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
677 lines
26 KiB
Plaintext
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ESM3\n",
"\n",
"ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n",
"\n",
"ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n",
"\n",
"The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n",
"Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%set_env TOKENIZERS_PARALLELISM=false\n",
"!pip install esm\n",
"import numpy as np\n",
"import torch\n",
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set up the client to Forge\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Your token is like a password for your account, so take care to protect it. For this reason, we recommend rotating tokens regularly: create a new one and delete old, unused ones. We also recommend pasting the token into an environment variable or entering it through a utility like `getpass`, as shown below, so that tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we can use the `ProteinChain` class from the `esm` SDK to grab a protein structure from the PDB.\n",
"We'll work with a human renal (kidney) dipeptidase, an enzyme that cleaves dipeptides (two amino acids joined by a peptide bond). Renal dipeptidases are of particular interest because they metabolize certain antibiotics.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pdb_id = \"1ITU\"  # PDB ID corresponding to Renal Dipeptidase\n",
"chain_id = \"A\"  # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n",
"renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n",
"# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A `ProteinChain` is an object that makes it easy to work with protein structures. Its `sequence` attribute holds the amino acid sequence of the protein.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(renal_dipep_chain.sequence)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein.\n",
"\n",
"The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n",
"print(renal_dipep_chain.atom37_positions[:3])"
]
},
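{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check of the `nan` convention described above, the cell below counts how many of the 37 atom slots are populated for the first few residues. This is a minimal sketch that relies only on `numpy` and the `atom37_positions` array printed above; the names `atom_is_resolved` and `atoms_per_residue` are ours for illustration and are not part of the `esm` SDK.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# An atom slot counts as resolved only if none of its x, y, z coordinates is nan\n",
"atom_is_resolved = ~np.isnan(renal_dipep_chain.atom37_positions).any(axis=-1)\n",
"# Number of populated atom slots for each residue (at most 37)\n",
"atoms_per_residue = atom_is_resolved.sum(axis=-1)\n",
"for residue, n_atoms in zip(renal_dipep_chain.sequence[:5], atoms_per_residue[:5]):\n",
"    print(residue, int(n_atoms))"
]
},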
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can visualize the protein chain using the `py3Dmol` library.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First we can create a `py3Dmol` view object\n",
"view = py3Dmol.view(width=500, height=500)\n",
"# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n",
"pdb_str = renal_dipep_chain.to_pdb_string()\n",
"# Load the PDB string into the `py3Dmol` view object\n",
"view.addModel(pdb_str, \"pdb\")\n",
"# Set the style of the protein chain\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"# Zoom in on the protein chain\n",
"view.zoomTo()\n",
"# Display the protein chain\n",
"view.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"motif_inds = np.arange(123, 146)\n",
"# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n",
"motif_sequence = renal_dipep_chain[motif_inds].sequence\n",
"motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n",
"print(\"Motif sequence: \", motif_sequence)\n",
"print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in grey and the motif in cyan.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"view = py3Dmol.view(width=500, height=500)\n",
"view.addModel(pdb_str, \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
"motif_res_inds = (\n",
"    motif_inds + 1\n",
").tolist()  # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n",
"view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt_length = 200\n",
"# First, we can construct a sequence prompt of all masks\n",
"sequence_prompt = [\"_\"] * prompt_length\n",
"# Then, we can insert the motif sequence at an arbitrary position in the prompt (we chose index 72 here)\n",
"sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)\n",
"sequence_prompt = \"\".join(sequence_prompt)\n",
"print(\"Sequence prompt: \", sequence_prompt)\n",
"print(\"Length of sequence prompt: \", len(sequence_prompt))\n",
"\n",
"# Next, we can construct a structure prompt of all nan coordinates\n",
"structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n",
"# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n",
"structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(\n",
"    motif_atom37_positions\n",
")\n",
"print(\"Structure prompt shape: \", structure_prompt.shape)\n",
"print(\n",
"    \"Indices with structure conditioning: \",\n",
"    torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),\n",
")\n",
"\n",
"# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n",
"protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. Under the hood, the model performs `num_steps` forward passes and samples a set of tokens at each step until the track being generated is fully unmasked.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n",
"sequence_generation_config = GenerationConfig(\n",
"    track=\"sequence\",  # We want ESM3 to generate tokens for the sequence track\n",
"    num_steps=sequence_prompt.count(\"_\")\n",
"    // 2,  # We'll use num(mask tokens) // 2 steps to decode the sequence\n",
"    temperature=0.5,  # We'll use a temperature of 0.5 to control the randomness of the decoding process\n",
")\n",
"\n",
"# Now, we can use the `generate` method of the model to decode the sequence\n",
"sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n",
"print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n",
"print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)"
]
},
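{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for the decoding budget, the cell below works out how many masked positions are filled per step on average. This is simple arithmetic on the values defined above, not a description of the exact unmasking schedule ESM3 uses internally, which we do not assume here. The names `n_masked` and `n_steps` are ours for illustration.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the mask tokens in the prompt and read back the step budget we configured\n",
"n_masked = sequence_prompt.count(\"_\")\n",
"n_steps = sequence_generation_config.num_steps\n",
"print(\"Masked positions:\", n_masked)\n",
"print(\"Decoding steps:\", n_steps)\n",
"print(\"Average positions unmasked per step: {:.2f}\".format(n_masked / n_steps))"
]
},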
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"structure_prediction_config = GenerationConfig(\n",
"    track=\"structure\",  # We want ESM3 to generate tokens for the structure track\n",
"    num_steps=len(sequence_generation) // 8,\n",
"    temperature=0.7,\n",
")\n",
"structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n",
"structure_prediction = model.generate(\n",
"    structure_prediction_prompt, structure_prediction_config\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. The motif residues are colored in cyan.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert the generated structure back into a ProteinChain object\n",
"structure_prediction_chain = structure_prediction.to_protein_chain()\n",
"# Align the generated structure to the original structure using the motif residues\n",
"motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))\n",
"# Keep the chain returned by `align` (as in the helix shortening example) so the aligned coordinates are the ones displayed\n",
"structure_prediction_chain = structure_prediction_chain.align(\n",
"    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n",
")\n",
"crmsd = structure_prediction_chain.rmsd(\n",
"    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds\n",
")\n",
"print(\n",
"    \"cRMSD of the motif in the generated structure vs the original structure: \", crmsd\n",
")\n",
"\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
"view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n",
"view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n",
"view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n",
"view.addStyle(\n",
"    {\"resi\": (motif_inds_in_generation + 1).tolist()},\n",
"    {\"cartoon\": {\"color\": \"cyan\"}},\n",
"    viewer=(0, 1),\n",
")\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Secondary Structure Editing Example: Helix Shortening\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n",
"view = py3Dmol.view(width=500, height=500)\n",
"view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
"helix_region = np.arange(38, 111)  # zero-indexed\n",
"view.addStyle(\n",
"    {\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\": \"lightblue\"}}\n",
")\n",
"view.zoomTo()\n",
"view.show()\n",
"helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n",
"print(\n",
"    \"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\",\n",
"    helix_shortening_ss8,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The helix-coil-helix region in the original protein is 73 residues long. We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"shortened_region_length = 45\n",
"\n",
"# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n",
"sequence_prompt = (\n",
"    helix_shortening_chain.sequence[: helix_region[0]]\n",
"    + \"_\" * shortened_region_length\n",
"    + helix_shortening_chain.sequence[helix_region[-1] + 1 :]\n",
")\n",
"print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n",
"\n",
"# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n",
"ss8_prompt = (\n",
"    helix_shortening_ss8[: helix_region[0]]\n",
"    + (\n",
"        ((shortened_region_length - 3) // 2) * \"H\"\n",
"        + \"C\" * 3\n",
"        + ((shortened_region_length - 3) // 2) * \"H\"\n",
"    )\n",
"    + helix_shortening_ss8[helix_region[-1] + 1 :]\n",
")\n",
"print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n",
"print(\n",
"    \"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\",\n",
"    \" \" * helix_region[0]\n",
"    + ss8_prompt[helix_region[0] : helix_region[0] + shortened_region_length],\n",
")\n",
"\n",
"print(\"\")\n",
"print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n",
"print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n",
"print(\n",
"    \"Original SS8 for helix-coil-helix region:\\n\\t\",\n",
"    \" \" * helix_region[0]\n",
"    + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],\n",
")\n",
"\n",
"\n",
"# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n",
"protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Generating protein sequence...\")\n",
"sequence_generation = model.generate(\n",
"    protein_prompt,\n",
"    GenerationConfig(\n",
"        track=\"sequence\",\n",
"        num_steps=protein_prompt.sequence.count(\"_\") // 2,\n",
"        temperature=0.5,\n",
"    ),\n",
")\n",
"print(\"Folding protein...\")\n",
"structure_prediction = model.generate(\n",
"    ESMProtein(sequence=sequence_generation.sequence),\n",
"    GenerationConfig(\n",
"        track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0\n",
"    ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left). The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in red.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predicted_chain = structure_prediction.to_protein_chain()\n",
"# Align on the last 120 residues, which lie outside the edited region in both chains\n",
"predicted_chain = predicted_chain.align(\n",
"    helix_shortening_chain,\n",
"    mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),\n",
"    target_inds=np.arange(\n",
"        len(helix_shortening_chain) - 120, len(helix_shortening_chain)\n",
"    ),\n",
")\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
"view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
"view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
"view.addStyle(\n",
"    {\"resi\": (helix_region + 1).tolist()},\n",
"    {\"cartoon\": {\"color\": \"lightblue\"}},\n",
"    viewer=(0, 0),\n",
")\n",
"view.addStyle(\n",
"    {\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},\n",
"    {\"cartoon\": {\"color\": \"red\"}},\n",
"    viewer=(0, 1),\n",
")\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SASA Editing Example: Exposing a buried helix\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n",
"span_start = 105\n",
"span_end = 116\n",
"view = py3Dmol.view(width=500, height=500)\n",
"view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
"view.addStyle(\n",
"    {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n",
"    {\"cartoon\": {\"color\": \"red\"}},\n",
")\n",
"view.zoomTo()\n",
"view.show()\n",
"lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n",
"\n",
"1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n",
"2. Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n",
"structure_prompt[span_start:span_end] = torch.tensor(\n",
"    lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32\n",
")\n",
"\n",
"sasa_prompt = [None] * len(lipase_chain)\n",
"sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)\n",
"\n",
"print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n",
"\n",
"protein_prompt = ESMProtein(\n",
"    sequence=\"_\" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a more difficult task, so you may need to sample more generations from ESM3 before you find a solution. We'll sample 16 here and sort the generations by the TM-score predicted by ESM3 (pTM), highest first.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import concurrent.futures\n",
"\n",
"\n",
"def generate_protein_sequence_and_structure(protein_prompt, model):\n",
"    sequence_generation = model.generate(\n",
"        protein_prompt,\n",
"        GenerationConfig(\n",
"            track=\"sequence\",\n",
"            num_steps=protein_prompt.sequence.count(\"_\") // 2,\n",
"            temperature=0.5,\n",
"        ),\n",
"    )\n",
"    structure_prediction = model.generate(\n",
"        ESMProtein(sequence=sequence_generation.sequence),\n",
"        GenerationConfig(\n",
"            track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0.7\n",
"        ),\n",
"    )\n",
"    return structure_prediction\n",
"\n",
"\n",
"N_SAMPLES = 16\n",
"# Submit one generation request per sample so they run concurrently\n",
"with concurrent.futures.ThreadPoolExecutor(max_workers=N_SAMPLES) as executor:\n",
"    futures = [\n",
"        executor.submit(generate_protein_sequence_and_structure, protein_prompt, model)\n",
"        for _ in range(N_SAMPLES)\n",
"    ]\n",
"\n",
"    generated_proteins = [future.result() for future in futures]\n",
"\n",
"\n",
"# Sort generations by ptm\n",
"generated_proteins = sorted(\n",
"    generated_proteins, key=lambda x: x.ptm.item(), reverse=True\n",
")"
]
},
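{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before visualizing, it can help to look at the spread of pTM values across all samples. The cell below is a minimal sketch that uses only the `ptm` attribute already used for sorting above; the name `ptm_values` is ours for illustration.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Collect the predicted TM-score of every sampled protein as plain Python floats\n",
"ptm_values = [protein.ptm.item() for protein in generated_proteins]\n",
"print(\"Best pTM: {:.2f}\".format(max(ptm_values)))\n",
"print(\"Worst pTM: {:.2f}\".format(min(ptm_values)))\n",
"print(\"Mean pTM: {:.2f}\".format(sum(ptm_values) / len(ptm_values)))"
]
},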
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's visualize the top 4 generations by pTM, alongside the original protein (on the left).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N_SAMPLES_TO_SHOW = 4\n",
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))\n",
"view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
"for i in range(N_SAMPLES_TO_SHOW):\n",
"    print(\n",
"        \"PTM of generated protein {}: {:.2f}\".format(\n",
"            i + 1, generated_proteins[i].ptm.item()\n",
"        )\n",
"    )\n",
"    view.addModel(\n",
"        generated_proteins[i].to_protein_chain().to_pdb_string(),\n",
"        \"pdb\",\n",
"        viewer=(0, i + 1),\n",
"    )\n",
"view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
"view.addStyle(\n",
"    {\"resi\": (np.arange(span_start, span_end) + 1).tolist()},\n",
"    {\"cartoon\": {\"color\": \"red\"}},\n",
")\n",
"view.zoomTo()\n",
"view.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}