Files
esm/tools/invfold.ipynb
2025-01-16 10:34:33 -08:00

184 lines
6.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "4l-TA3Od1JFs"
},
"source": [
"# ESM3 Inverse Folding Notebook\n",
"\n",
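"This notebook performs inverse folding with ESM3: given a protein structure (a PDB ID or an uploaded PDB file), it uses the Forge API to generate amino-acid sequences for that structure and shows how conserved each position is across the generated designs.\n"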
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "1TwEAW_LSNZZ"
},
"outputs": [],
"source": [
"# @title Input API keys, then hit `Runtime` -> `Run all`\n",
"# @markdown Forge is our hosted service that provides access to the full suite of ESM3 models.\n",
"# @markdown To utilize the Forge API, users must first agree to the [Terms of Service](https://forge.evolutionaryscale.ai/termsofservice) and generate an access token via the [Forge console](https://forge.evolutionaryscale.ai/console).\n",
"# @markdown The console also provides a comprehensive list of models available to each user.\n",
"\n",
"import os\n",
"\n",
"# @markdown ### Authentication\n",
"# @markdown Paste your token from the [Forge console](https://forge.evolutionaryscale.ai/console)\n",
"# The token is exported as the ESM_API_KEY environment variable. It is never\n",
"# passed to get_forge_client() directly, so that helper presumably reads it\n",
"# from the environment -- NOTE(review): confirm.\n",
"forge_token = \"\" # @param {type:\"string\"}\n",
"os.environ[\"ESM_API_KEY\"] = forge_token\n",
"\n",
"# @markdown ### Model Selection\n",
"# @markdown Enter the model name from the [Forge console page](https://forge.evolutionaryscale.ai/console) that you would like to use:\n",
"model_name = \"esm3-medium-2024-08\" # @param {type:\"string\"}\n",
"\n",
"# @markdown ### Input Structure\n",
"pdb_code = \"\" # @param {type:\"string\"}\n",
"chain = \"detect\" # @param {type:\"string\"}\n",
"# @markdown Enter a PDB ID, or leave it blank to upload a PDB file.\n",
"# @markdown Specify a chain ID if the structure is a complex.\n",
"# `chain` is passed unchanged as ESMProtein.from_pdb(chain_id=...).\n",
"# NOTE(review): \"detect\" is assumed to let the SDK pick a chain -- confirm.\n",
"\n",
"# @markdown ### Design Parameters\n",
"# temperature: sampling temperature passed to GenerationConfig.\n",
"# num_sequences: number of designs requested in a single batch.\n",
"temperature = 0.1 # @param {type:\"slider\", min:0.0, max:1.0, step:0.01}\n",
"num_sequences = 8 # @param {type:\"integer\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "_942E63WS8-U"
},
"outputs": [],
"source": [
"# @title Install dependencies\n",
"# Packages are installed with the kernel's own interpreter (`python -m pip`),\n",
"# so they land in the environment this notebook actually runs in.\n",
"# check_call raises CalledProcessError on a failed install; the cell stops\n",
"# here instead of failing later with a confusing ImportError.\n",
"import os\n",
"import subprocess\n",
"import sys\n",
"\n",
"\n",
"def _pip_install(*packages: str) -> None:\n",
"    \"\"\"Install `packages` with the current interpreter's pip; raise on failure.\"\"\"\n",
"    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", *packages])\n",
"\n",
"\n",
"_pip_install(\"git+https://github.com/evolutionaryscale/esm\")\n",
"_pip_install(\n",
"    \"pydssp\", \"pygtrie\", \"dna-features-viewer\", \"py3dmol\", \"nest-asyncio\", \"ipywidgets\"\n",
")\n",
"\n",
"import nest_asyncio  # noqa: E402\n",
"\n",
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "jXl61b-zTIsp"
},
"outputs": [],
"source": [
"# @title Run Inverse Folding\n",
"import numpy as np\n",
"from esm.sdk.api import ESMProtein, ESMProteinError, GenerationConfig\n",
"from esm.widgets.utils.clients import get_forge_client\n",
"from google.colab import files\n",
"from IPython.display import HTML\n",
"\n",
"\n",
"def get_pdb(pdb_code=\"\"):\n",
"    \"\"\"Return a local path to the input structure as a PDB file.\n",
"\n",
"    If `pdb_code` is None, empty or whitespace, prompt for a file upload via\n",
"    Colab and save the first uploaded file as `tmp.pdb`. Otherwise download\n",
"    `{pdb_code}.pdb` from RCSB into the working directory; an existing file\n",
"    with that name is reused instead of being downloaded again.\n",
"\n",
"    Raises:\n",
"        ValueError: if no file is uploaded, or `pdb_code` contains characters\n",
"            other than ASCII letters, digits and underscores.\n",
"        urllib.error.URLError: if the RCSB download fails (e.g. unknown ID).\n",
"    \"\"\"\n",
"    import re\n",
"    import urllib.request\n",
"\n",
"    pdb_code = (pdb_code or \"\").strip()\n",
"    if not pdb_code:\n",
"        upload_dict = files.upload()\n",
"        if not upload_dict:\n",
"            raise ValueError(\"No structure file was uploaded.\")\n",
"        pdb_string = next(iter(upload_dict.values()))\n",
"        with open(\"tmp.pdb\", \"wb\") as out:\n",
"            out.write(pdb_string)\n",
"        return \"tmp.pdb\"\n",
"\n",
"    # The code ends up in both a URL and a local file name, so only plain\n",
"    # identifier characters are allowed (no '/', '..', spaces or shell syntax).\n",
"    if not re.fullmatch(r\"[A-Za-z0-9_]+\", pdb_code):\n",
"        raise ValueError(f\"Invalid PDB code: {pdb_code!r}\")\n",
"    pdb_path = f\"{pdb_code}.pdb\"\n",
"    # Like `wget -nc`: keep an existing download instead of fetching again.\n",
"    if not os.path.exists(pdb_path):\n",
"        # Raises HTTPError (a URLError) on e.g. 404 instead of failing silently.\n",
"        urllib.request.urlretrieve(\n",
"            f\"https://files.rcsb.org/view/{pdb_code}.pdb\", pdb_path\n",
"        )\n",
"    return pdb_path\n",
"\n",
"\n",
"# Validate form input before prompting for an upload or calling the API.\n",
"if num_sequences < 1:\n",
"    raise ValueError(f\"num_sequences must be at least 1, got {num_sequences}\")\n",
"\n",
"print(\"Loading structure...\")\n",
"pdb_path = get_pdb(pdb_code)\n",
"\n",
"# Keep the structure but clear the sequence, so the sequence track is what\n",
"# the model generates, conditioned on the loaded structure.\n",
"protein = ESMProtein.from_pdb(pdb_path, chain_id=chain)\n",
"protein.sequence = None\n",
"\n",
"print(\"Running inverse folding...\")\n",
"client = get_forge_client(model_name)\n",
"# The same structure and config are submitted num_sequences times in one batch.\n",
"generations = client.batch_generate(\n",
"    inputs=[protein] * num_sequences,\n",
"    configs=[GenerationConfig(track=\"sequence\", temperature=temperature)]\n",
"    * num_sequences,\n",
")\n",
"\n",
"# batch_generate returns per-item failures as ESMProteinError values instead\n",
"# of raising, so keep each one with its batch index for reporting. The loop\n",
"# variable is `generation` so the input `protein` is not overwritten.\n",
"errors: list[tuple[int, ESMProteinError]] = []\n",
"sequences: list[str] = []\n",
"for i, generation in enumerate(generations):\n",
"    if isinstance(generation, ESMProteinError):\n",
"        errors.append((i, generation))\n",
"    else:\n",
"        sequences.append(generation.sequence)\n",
"\n",
"\n",
"def calculate_conservation_scores(sequences: list[str]) -> np.ndarray:\n",
"    \"\"\"Per-position conservation score for a set of equal-length sequences.\n",
"\n",
"    For each column, computes the Shannon entropy H (natural log) of the\n",
"    characters across `sequences` and maps it to a display score\n",
"    max(0, 0.5 - H / log(256)) / 0.5. A fully conserved column scores ~1.0;\n",
"    the score drops as the column becomes more variable.\n",
"\n",
"    Args:\n",
"        sequences: ASCII strings, all of the same length. Any ASCII character\n",
"            is accepted (e.g. '|', '-', lowercase letters).\n",
"\n",
"    Returns:\n",
"        Float array of shape (sequence_length,), approximately in [0, 1].\n",
"        Empty when `sequences` is empty.\n",
"    \"\"\"\n",
"    if len(sequences) == 0:\n",
"        return np.zeros(0)\n",
"\n",
"    # (num_sequences, sequence_length) matrix of raw byte values.\n",
"    array = np.array([list(seq) for seq in sequences], dtype=\"S1\").view(np.uint8)\n",
"\n",
"    # Count every possible byte value per column. Using all 256 bins (instead\n",
"    # of only 'A'..'Z') means characters outside that range cannot overflow\n",
"    # the count array; unused bins contribute nothing to the entropy.\n",
"    num_bins = 256\n",
"    length = array.shape[1]\n",
"    counts = np.zeros((length, num_bins), dtype=int)\n",
"    for position in range(length):\n",
"        counts[position] = np.bincount(array[:, position], minlength=num_bins)\n",
"\n",
"    # Entropy per column: -sum(p log p); the 1e-9 avoids log(0).\n",
"    probabilities = counts / counts.sum(axis=1, keepdims=True)\n",
"    entropy = -np.sum(probabilities * np.log(probabilities + 1e-9), axis=1)\n",
"\n",
"    # Normalize by log(256), the entropy of a uniform distribution over all\n",
"    # byte values. The 0.5 cut-off/rescale is a magic constant that makes\n",
"    # non-conserved residues more apparent when displayed.\n",
"    max_entropy = np.log(num_bins)\n",
"    conservation_scores = np.maximum(0, 0.5 - (entropy / max_entropy)) / 0.5\n",
"    return conservation_scores\n",
"\n",
"\n",
"def display_sequences(sequences: list[str]):\n",
"    \"\"\"Render `sequences` as stacked rows of HTML in the notebook output.\n",
"\n",
"    Each residue gets a background colour whose alpha is the conservation\n",
"    score of its column (from calculate_conservation_scores), so more\n",
"    conserved positions are shaded more strongly. With an empty list, only\n",
"    a notice is printed, because no conservation profile can be computed.\n",
"    \"\"\"\n",
"    if not sequences:\n",
"        print(\"No sequences to display.\")\n",
"        return\n",
"    conservation_scores = calculate_conservation_scores(sequences)\n",
"    rows = []\n",
"    for sequence in sequences:\n",
"        # One <span> per residue; the background alpha is that column's score.\n",
"        rows.append(\n",
"            \"\".join(\n",
"                f'<span style=\"background-color: rgba(9, 121, 105,{conservation_scores[j]})\">{residue}</span>'\n",
"                for j, residue in enumerate(sequence)\n",
"            )\n",
"        )\n",
"    html_output = (\n",
"        '<pre style=\"line-height:1.0;letter-spacing:3px;font-family:monospace;margin:0;padding:0\">'\n",
"        + \"\".join(row + \"<br>\" for row in rows)\n",
"        + \"</pre>\"\n",
"    )\n",
"    display(HTML(html_output))\n",
"\n",
"\n",
"# Report failed generations first, so the errors stay visible even when no\n",
"# sequences came back and there is nothing to display.\n",
"for i, error in errors:\n",
"    print(f\"Error code {error.error_code} at index {i}: {error.error_msg}\")\n",
"\n",
"if sequences:\n",
"    display_sequences(sequences)\n",
"else:\n",
"    print(\"No sequences were generated; nothing to display.\")"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}