mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
184 lines
6.6 KiB
Plaintext
184 lines
6.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "4l-TA3Od1JFs"
|
|
},
|
|
"source": [
|
|
"# ESM3 Inverse Folding Notebook\n",
|
|
"\n",
|
|
"This notebook is intended to be used as a tool for inverse folding using the ESM3 model.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "1TwEAW_LSNZZ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Input API keys, then hit `Runtime` -> `Run all`\n",
|
|
"# @markdown Our hosted service that provides access to the full suite of ESM3 models.\n",
|
|
"# @markdown To utilize the Forge API, users must first agree to the [Terms of Service](https://forge.evolutionaryscale.ai/termsofservice) and generate an access token via the [Forge console](https://forge.evolutionaryscale.ai/console).\n",
|
|
"# @markdown The console also provides a comprehensive list of models available to each user.\n",
|
|
"\n",
|
|
"import os\n",
|
|
"\n",
|
|
"# @markdown ### Authentication\n",
|
|
"# @markdown Paste your token from the [Forge console](https://forge.evolutionaryscale.ai/console)\n",
|
|
"forge_token = \"\" # @param {type:\"string\"}\n",
|
|
"os.environ[\"ESM_API_KEY\"] = forge_token\n",
|
|
"\n",
|
|
"# @markdown ### Model Selection\n",
|
|
"# @markdown Enter the model name from the [Forge console page](https://forge.evolutionaryscale.ai/console) that you would like to use:\n",
|
|
"model_name = \"esm3-medium-2024-08\" # @param {type:\"string\"}\n",
|
|
"\n",
|
|
"# @markdown ### Input Structure\n",
|
|
"pdb_code = \"\" # @param {type:\"string\"}\n",
|
|
"chain = \"detect\" # @param {type:\"string\"}\n",
|
|
"# @markdown Enter PDB code or leave blank to upload file\n",
|
|
"# @markdown Specify a chain if uploading a complex\n",
|
|
"\n",
|
|
"# @markdown ### Design Parameters\n",
|
|
"temperature = 0.1 # @param {type:\"slider\", min:0.0, max:1.0, step:0.01}\n",
|
|
"num_sequences = 8 # @param {type:\"integer\"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "_942E63WS8-U"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Install dependencies\n",
|
|
"import os\n",
|
|
"\n",
|
|
"os.system(\"pip install git+https://github.com/evolutionaryscale/esm\")\n",
|
|
"os.system(\n",
|
|
" \"pip install pydssp pygtrie dna-features-viewer py3dmol nest-asyncio ipywidgets\"\n",
|
|
")\n",
|
|
"\n",
|
|
"import nest_asyncio # noqa: E402\n",
|
|
"\n",
|
|
"nest_asyncio.apply()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "jXl61b-zTIsp"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Run Inverse Folding\n",
|
|
"import numpy as np\n",
|
|
"from esm.sdk.api import ESMProtein, ESMProteinError, GenerationConfig\n",
|
|
"from esm.widgets.utils.clients import get_forge_client\n",
|
|
"from google.colab import files\n",
|
|
"from IPython.display import HTML\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_pdb(pdb_code=\"\"):\n",
|
|
" if pdb_code is None or pdb_code == \"\":\n",
|
|
" upload_dict = files.upload()\n",
|
|
" pdb_string = upload_dict[list(upload_dict.keys())[0]]\n",
|
|
" with open(\"tmp.pdb\", \"wb\") as out:\n",
|
|
" out.write(pdb_string)\n",
|
|
" return \"tmp.pdb\"\n",
|
|
" else:\n",
|
|
" os.system(f\"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb\")\n",
|
|
" return f\"{pdb_code}.pdb\"\n",
|
|
"\n",
|
|
"\n",
|
|
"print(\"Loading structure...\")\n",
|
|
"pdb_path = get_pdb(pdb_code)\n",
|
|
"\n",
|
|
"# Create protein object\n",
|
|
"protein = ESMProtein.from_pdb(pdb_path, chain_id=chain)\n",
|
|
"protein.sequence = None\n",
|
|
"\n",
|
|
"print(\"Running inverse folding...\")\n",
|
|
"client = get_forge_client(model_name)\n",
|
|
"generations = client.batch_generate(\n",
|
|
" inputs=[protein] * num_sequences,\n",
|
|
" configs=[GenerationConfig(track=\"sequence\", temperature=temperature)]\n",
|
|
" * num_sequences,\n",
|
|
")\n",
|
|
"\n",
|
|
"if isinstance(protein, ESMProteinError):\n",
|
|
" raise RuntimeError(f\"Error: {str(protein)}\")\n",
|
|
"\n",
|
|
"errors: list[ESMProteinError] = []\n",
|
|
"sequences: list[str] = []\n",
|
|
"for i, protein in enumerate(generations):\n",
|
|
" if isinstance(protein, ESMProteinError):\n",
|
|
" errors.append((i, protein))\n",
|
|
" else:\n",
|
|
" sequences.append(protein.sequence)\n",
|
|
"\n",
|
|
"\n",
|
|
"def calculate_conservation_scores(sequences: list[str]) -> np.ndarray:\n",
|
|
" array = np.array([list(seq) for seq in sequences], dtype=\"S1\")\n",
|
|
" array = array.view(np.uint8) - ord(\"A\")\n",
|
|
"\n",
|
|
" # Create a 2D array of counts\n",
|
|
" max_range = 26\n",
|
|
" counts = np.zeros((max_range + 1, array.shape[1]), dtype=int)\n",
|
|
" for col in range(array.shape[1]):\n",
|
|
" count = np.bincount(array[:, col], minlength=max_range + 1)\n",
|
|
" counts[:, col] = count\n",
|
|
" counts = counts.T\n",
|
|
"\n",
|
|
" # Calculate entropy (-sum(p log p))\n",
|
|
" probabilities = counts / counts.sum(axis=1, keepdims=True)\n",
|
|
" entropy = -np.sum(probabilities * np.log(probabilities + 1e-9), axis=1)\n",
|
|
"\n",
|
|
" # Convert to conservation score (1 - normalized entropy)\n",
|
|
" max_entropy = np.log(256)\n",
|
|
" # Magic constant to make displaying non-conserved residues more apparent\n",
|
|
" conservation_scores = np.maximum(0, 0.5 - (entropy / max_entropy)) / 0.5\n",
|
|
"\n",
|
|
" return conservation_scores\n",
|
|
"\n",
|
|
"\n",
|
|
"def display_sequences(sequences: list[str]):\n",
|
|
" conservation_scores = calculate_conservation_scores(sequences)\n",
|
|
" html_output = '<pre style=\"line-height:1.0;letter-spacing:3px;font-family:monospace;margin:0;padding:0\">'\n",
|
|
" for sequence in sequences:\n",
|
|
" for j, residue in enumerate(sequence):\n",
|
|
" # Add padding for alignment and color the background\n",
|
|
" html_output += f'<span style=\"background-color: rgba(9, 121, 105,{conservation_scores[j]})\">{residue}</span>'\n",
|
|
" html_output += \"<br>\"\n",
|
|
" html_output += \"</pre>\"\n",
|
|
" display(HTML(html_output))\n",
|
|
"\n",
|
|
"\n",
|
|
"display_sequences(sequences)\n",
|
|
"\n",
|
|
"for i, error in errors:\n",
|
|
" print(f\"Error code {error.error_code} at index {i}: {error.error_msg}\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|