mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Add guided generation (#228)
This commit is contained in:
405
cookbook/tutorials/5_guided_generation.ipynb
Normal file
405
cookbook/tutorials/5_guided_generation.ipynb
Normal file
@@ -0,0 +1,405 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Guided Generation with ESM3\n",
|
||||
"\n",
|
||||
"Guided generation is a powerful tool that allows you to sample outputs out of ESM3 that maximize any kind of score function.\n",
|
||||
"\n",
|
||||
"For example, you may want to\n",
|
||||
"1. Guide generations towards higher quality metrics like pTM\n",
|
||||
"2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes\n",
|
||||
"3. Minimize a biophysical energy function\n",
|
||||
"4. Use experimental screening data to guide designs with a regression model\n",
|
||||
"\n",
|
||||
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Uehara, et al 2024](https://arxiv.org/abs/2408.08252).\n",
|
||||
"\n",
|
||||
"In this notebook we will walk through a few examples to illustrate how to use guided generation. \n",
|
||||
"\n",
|
||||
"1. Guide towards high pTM for improved generation quality\n",
|
||||
"\n",
|
||||
"2. Generate a protein with no Cysteine residues\n",
|
||||
"\n",
|
||||
"3. Maximize protein globularity by minimizing the radius of gyration\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title **IMPORTANT:** run this cell first before doing 'Runtime →> Run all'\n",
|
||||
"#@markdown - The latest update to Google Colab broke numpy; this is a temporary patch.\n",
|
||||
"#@markdown - Note after running this cell, the session will crash (this is normal).\n",
|
||||
"\n",
|
||||
"import os, numpy, signal\n",
|
||||
"\n",
|
||||
"if numpy.__version__ != '1.26.4':\n",
|
||||
" print(f\"Current numpy version {numpy.__version__} is incorrect. Installing 1.26.4...\")\n",
|
||||
" os.system(\"'pip uninstall -y numpy\")\n",
|
||||
" os.system(\"pip install numpy==1.26.4\")\n",
|
||||
" # Restart the runtime using os.kill\n",
|
||||
" os.kill(os. getpid(), signal.SIGKILL)\n",
|
||||
"else:\n",
|
||||
" print (\"Numpy version is correct (1.26.4)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install git+https://github.com/evolutionaryscale/esm.git\n",
|
||||
"!pip install py3dmol"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import biotite.structure as bs\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.models.esm3 import ESM3\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.sdk.experimental import (\n",
|
||||
" ESM3GuidedDecoding,\n",
|
||||
" GuidedDecodingScoringFunction,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Creating a scoring function\n",
|
||||
"\n",
|
||||
"To get started with the guided generation API the only thing you need is to create a callable class that inherits from `GuidedDecodingScoringFunction`. This class should receive as input an `ESMProtein` object and output a numerical score.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the Predicted Template Modelling (pTM) score, so we'll use it to create a `PTMScoringFunction`.\n",
|
||||
"\n",
|
||||
"Fortunately for us, every time we generate a protein using ESM3 (either locally or on Forge) we also get its pTM, so all our class needs to do when its called is to return the `ptm` attribute of its input."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create scoring function (e.g. PTM scoring function)\n",
|
||||
"class PTMScoringFunction(GuidedDecodingScoringFunction):\n",
|
||||
" def __call__(self, protein: ESMProtein) -> float:\n",
|
||||
" # Minimal example of a scoring function that scores proteins based on their pTM score\n",
|
||||
" # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score\n",
|
||||
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
|
||||
" return float(protein.ptm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Initialize your client\n",
|
||||
"\n",
|
||||
"The guided generation is compatible with both local inference using the `ESM3` class and remote inference with the Forge client"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# To use the tokenizers and the open model you'll need to login into Hugging Face\n",
|
||||
"\n",
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"## Locally with ESM3-open\n",
|
||||
"model = ESM3.from_pretrained().to(\"cuda\")\n",
|
||||
"\n",
|
||||
"## On Forge with larger ESM3 models\n",
|
||||
"# from getpass import getpass\n",
|
||||
"\n",
|
||||
"# from esm.sdk import client\n",
|
||||
"\n",
|
||||
"# token = getpass(\"Token from Forge console: \")\n",
|
||||
"# model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## pTM Guided Generation\n",
|
||||
"\n",
|
||||
"Once your scoring function is defined and you have initialized your model you can create an `ESM3GuidedDecoding` instance to sample from it"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ptm_guided_decoding = ESM3GuidedDecoding(\n",
|
||||
" client=model, scoring_function=PTMScoringFunction()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Start from a fully masked protein\n",
|
||||
"PROTEIN_LENGTH = 256\n",
|
||||
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
|
||||
"\n",
|
||||
"# Call guided_generate\n",
|
||||
"generated_protein = ptm_guided_decoding.guided_generate(\n",
|
||||
" protein=starting_protein,\n",
|
||||
" num_decoding_steps=len(starting_protein) // 8,\n",
|
||||
" num_samples_per_step=10,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Compare against baseline with no guidance\n",
|
||||
"\n",
|
||||
"First we are going to sample a protein generated without any guidance. This means that, when not providing pTM guidance, we could be sampling proteins that have no clear structure."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate a protein WITHOUT guidance\n",
|
||||
"generated_protein_no_guided: ESMProtein = model.generate(\n",
|
||||
" input=starting_protein,\n",
|
||||
" config=GenerationConfig(track=\"sequence\", num_steps=len(starting_protein) // 8),\n",
|
||||
") # type: ignore\n",
|
||||
"\n",
|
||||
"# Fold\n",
|
||||
"generated_protein_no_guided: ESMProtein = model.generate(\n",
|
||||
" input=generated_protein_no_guided,\n",
|
||||
" config=GenerationConfig(track=\"structure\", num_steps=1),\n",
|
||||
") # type: ignore"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create a 1x2 grid of viewers (1 row, 2 columns)\n",
|
||||
"view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
|
||||
"\n",
|
||||
"# Convert ESMProtein objects to ProteinChain objects\n",
|
||||
"protein_chain1 = generated_protein_no_guided.to_protein_chain()\n",
|
||||
"protein_chain2 = generated_protein.to_protein_chain()\n",
|
||||
"\n",
|
||||
"# Add models to respective panels\n",
|
||||
"view.addModel(protein_chain1.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
|
||||
"view.addModel(protein_chain2.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
|
||||
"\n",
|
||||
"# Set styles for each protein\n",
|
||||
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 0))\n",
|
||||
"view.setStyle({}, {\"cartoon\": {\"color\": \"spectrum\"}}, viewer=(0, 1))\n",
|
||||
"\n",
|
||||
"# Zoom and center the view\n",
|
||||
"view.zoomTo()\n",
|
||||
"view.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Generate a Protein with No Cysteines\n",
|
||||
"\n",
|
||||
"Guided generation is not constrained to structural metrics, you can also use it to guide the sequence generation.\n",
|
||||
"\n",
|
||||
"For example, we can create a `NoCysteineScoringFunction` that penalizes the protein if it contains Cysteine residues"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class NoCysteineScoringFunction(GuidedDecodingScoringFunction):\n",
|
||||
" def __call__(self, protein: ESMProtein) -> float:\n",
|
||||
" # Penalize proteins that contain cysteine\n",
|
||||
" assert protein.sequence is not None, \"Protein must have a sequence to be scored\"\n",
|
||||
" # Note that we use a negative score here, to discourage the presence of cysteine\n",
|
||||
" return -protein.sequence.count(\"C\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"no_cysteine_guided_decoding = ESM3GuidedDecoding(\n",
|
||||
" client=model, scoring_function=NoCysteineScoringFunction()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
|
||||
" protein=starting_protein,\n",
|
||||
" num_decoding_steps=len(starting_protein) // 8,\n",
|
||||
" num_samples_per_step=10,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's check our sequence!\n",
|
||||
"\n",
|
||||
"If guided generation converged to `score == 0.00`, the resulting protein should contain no Cysteine residues"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"assert no_cysteine_protein.sequence is not None, \"Protein must have a sequence\"\n",
|
||||
"print(no_cysteine_protein.sequence)\n",
|
||||
"print(f\"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Maximize Globularity\n",
|
||||
"\n",
|
||||
"We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
|
||||
" def __call__(self, protein: ESMProtein) -> float:\n",
|
||||
" score = -1 * self.radius_of_gyration(protein)\n",
|
||||
"\n",
|
||||
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
|
||||
" if protein.ptm < 0.5:\n",
|
||||
" # Penalize proteins with low pTM scores\n",
|
||||
" score = score * 2\n",
|
||||
"\n",
|
||||
" return score\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def radius_of_gyration(protein: ESMProtein) -> float:\n",
|
||||
" protein_chain = protein.to_protein_chain()\n",
|
||||
" arr = protein_chain.atom_array_no_insertions\n",
|
||||
" return bs.gyration_radius(arr)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"radius_guided_decoding = ESM3GuidedDecoding(\n",
|
||||
" client=model, scoring_function=RadiousOfGyrationScoringFunction()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
|
||||
" protein=starting_protein,\n",
|
||||
" num_decoding_steps=len(starting_protein) // 8,\n",
|
||||
" num_samples_per_step=10,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"view = py3Dmol.view(width=800, height=400)\n",
|
||||
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
|
||||
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
|
||||
"view.zoomTo()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1 +1,2 @@
|
||||
__version__ = "3.1.5post1"
|
||||
__version__ = "3.1.7"
|
||||
|
||||
|
||||
209
esm/sdk/experimental.py
Normal file
209
esm/sdk/experimental.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import attr
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from esm.models.esm3 import ESM3
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.forge import ESM3ForgeInferenceClient
|
||||
from esm.tokenization import get_esm3_model_tokenizers
|
||||
|
||||
|
||||
class GuidedDecodingScoringFunction(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, protein: ESMProtein) -> float:
|
||||
pass
|
||||
|
||||
|
||||
class ESM3GuidedDecoding:
|
||||
"""This class can be used to perform derivative-free guided decoding, based on
|
||||
the method described in "Derivative-Free Guidance in Continuous and Discrete Diffusion Models with Soft Value-Based Decoding"
|
||||
https://arxiv.org/abs/2408.08252
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ESM3InferenceClient,
|
||||
scoring_function: GuidedDecodingScoringFunction,
|
||||
):
|
||||
if isinstance(client, ESM3):
|
||||
self.tokenizers = client.tokenizers
|
||||
elif isinstance(client, ESM3ForgeInferenceClient):
|
||||
self.tokenizers = get_esm3_model_tokenizers(client.model)
|
||||
else:
|
||||
raise ValueError(
|
||||
"client must be an instance of ESM3 or ESM3ForgeInferenceClient"
|
||||
)
|
||||
|
||||
self.client = client
|
||||
self.scoring_function = scoring_function
|
||||
|
||||
def guided_generate(
|
||||
self,
|
||||
protein: ESMProtein,
|
||||
num_decoding_steps: int,
|
||||
num_samples_per_step: int,
|
||||
denoised_prediction_temperature: float = 0.0,
|
||||
track: str = "sequence",
|
||||
verbose: bool = True,
|
||||
) -> ESMProtein:
|
||||
protein_tensor = self.client.encode(protein)
|
||||
|
||||
assert not isinstance(protein_tensor, ESMProteinError)
|
||||
|
||||
if track == "structure":
|
||||
protein_tensor = self.maybe_add_default_structure_tokens(protein_tensor)
|
||||
|
||||
num_masked_positions = self.get_number_of_masked_positions(
|
||||
protein_tensor, track=track
|
||||
)
|
||||
num_positions_to_unmask = num_masked_positions // num_decoding_steps
|
||||
|
||||
current_score = -1
|
||||
|
||||
if verbose:
|
||||
pbar = tqdm(range(num_decoding_steps), desc="Current score: -1")
|
||||
else:
|
||||
pbar = range(num_decoding_steps)
|
||||
|
||||
for step in pbar:
|
||||
if step == num_decoding_steps - 1:
|
||||
# At the last step, unmask all remaining positions
|
||||
num_positions_to_unmask = self.get_number_of_masked_positions(
|
||||
protein_tensor, track=track
|
||||
)
|
||||
|
||||
samples = []
|
||||
scores = []
|
||||
for _ in range(num_samples_per_step):
|
||||
sample = self.randomly_unmask_positions(
|
||||
protein_tensor, num_positions_to_unmask, track=track
|
||||
)
|
||||
scores.append(
|
||||
self.reward_function(
|
||||
sample,
|
||||
denoised_prediction_temperature=denoised_prediction_temperature,
|
||||
)
|
||||
)
|
||||
samples.append(sample)
|
||||
|
||||
# Select best scoring sample
|
||||
best_sample = samples[scores.index(max(scores))]
|
||||
current_score = max(scores)
|
||||
protein_tensor = best_sample
|
||||
|
||||
if verbose:
|
||||
pbar.set_description(f"Current score: {current_score:.2f}") # type: ignore
|
||||
|
||||
# Fully predict and decode final protein
|
||||
protein_tensor_output = self.client.forward_and_sample(
|
||||
protein_tensor,
|
||||
SamplingConfig(
|
||||
sequence=SamplingTrackConfig(temperature=0.0),
|
||||
structure=SamplingTrackConfig(temperature=0.0),
|
||||
),
|
||||
)
|
||||
|
||||
assert not isinstance(protein_tensor_output, ESMProteinError)
|
||||
protein_tensor = protein_tensor_output.protein_tensor
|
||||
|
||||
decoded_protein = self.client.decode(protein_tensor)
|
||||
assert not isinstance(decoded_protein, ESMProteinError)
|
||||
return decoded_protein
|
||||
|
||||
def reward_function(
|
||||
self,
|
||||
protein_tensor: ESMProteinTensor,
|
||||
denoised_prediction_temperature: float = 0.0,
|
||||
) -> float:
|
||||
denoised_protein = self.predict_denoised(
|
||||
protein_tensor, temperature=denoised_prediction_temperature
|
||||
)
|
||||
return self.scoring_function(denoised_protein)
|
||||
|
||||
def get_number_of_masked_positions(
|
||||
self, protein_tensor: ESMProteinTensor, track: str = "sequence"
|
||||
) -> int:
|
||||
assert isinstance(protein_tensor, ESMProteinTensor)
|
||||
track_tensor = getattr(protein_tensor, track)
|
||||
track_tokenizer = getattr(self.tokenizers, track)
|
||||
is_mask = track_tensor == track_tokenizer.mask_token_id
|
||||
return is_mask.sum().item() # type: ignore
|
||||
|
||||
def randomly_unmask_positions(
|
||||
self,
|
||||
protein_tensor: ESMProteinTensor,
|
||||
num_positions_to_unmask: int,
|
||||
temperature: float = 1.0,
|
||||
track: str = "sequence",
|
||||
) -> ESMProteinTensor:
|
||||
track_tensor = getattr(protein_tensor, track)
|
||||
assert track_tensor is not None
|
||||
protein_tensor = attr.evolve(protein_tensor)
|
||||
setattr(protein_tensor, track, track_tensor.clone())
|
||||
|
||||
track_tensor = getattr(protein_tensor, track)
|
||||
track_tokenizer = getattr(self.tokenizers, track)
|
||||
|
||||
is_mask = track_tensor == track_tokenizer.mask_token_id
|
||||
num_masked_positions = is_mask.sum().item()
|
||||
|
||||
if num_positions_to_unmask > num_masked_positions:
|
||||
num_positions_to_unmask = num_masked_positions # type: ignore
|
||||
|
||||
mask_indices = is_mask.nonzero(as_tuple=False)
|
||||
mask_indices = mask_indices[torch.randperm(mask_indices.size(0))]
|
||||
mask_indices = mask_indices[:num_positions_to_unmask]
|
||||
|
||||
sampling_config = SamplingConfig()
|
||||
setattr(sampling_config, track, SamplingTrackConfig(temperature=temperature))
|
||||
|
||||
denoised_protein_tensor_output = self.client.forward_and_sample(
|
||||
protein_tensor, sampling_configuration=sampling_config
|
||||
)
|
||||
assert not isinstance(denoised_protein_tensor_output, ESMProteinError)
|
||||
denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor
|
||||
output_track_tensor = getattr(denoised_protein_tensor, track)
|
||||
assert output_track_tensor is not None
|
||||
track_tensor[mask_indices] = output_track_tensor[mask_indices]
|
||||
setattr(protein_tensor, track, track_tensor)
|
||||
|
||||
return protein_tensor
|
||||
|
||||
def predict_denoised(
|
||||
self, protein_tensor: ESMProteinTensor, temperature: float = 0.0
|
||||
) -> ESMProtein:
|
||||
denoised_protein_tensor_output = self.client.forward_and_sample(
|
||||
protein_tensor,
|
||||
sampling_configuration=SamplingConfig(
|
||||
sequence=SamplingTrackConfig(temperature=temperature),
|
||||
structure=SamplingTrackConfig(temperature=temperature),
|
||||
),
|
||||
)
|
||||
assert not isinstance(denoised_protein_tensor_output, ESMProteinError)
|
||||
denoised_protein_tensor = denoised_protein_tensor_output.protein_tensor
|
||||
denoised_protein = self.client.decode(denoised_protein_tensor)
|
||||
assert not isinstance(denoised_protein, ESMProteinError)
|
||||
return denoised_protein
|
||||
|
||||
def maybe_add_default_structure_tokens(
|
||||
self, protein_tensor: ESMProteinTensor
|
||||
) -> ESMProteinTensor:
|
||||
empty_protein_tensor = ESMProteinTensor.empty(
|
||||
len(protein_tensor) - 2,
|
||||
tokenizers=self.tokenizers,
|
||||
device=protein_tensor.device,
|
||||
)
|
||||
if protein_tensor.structure is None:
|
||||
setattr(protein_tensor, "structure", empty_protein_tensor.structure)
|
||||
else:
|
||||
print("Warning: structure already exists in protein_tensor")
|
||||
return protein_tensor
|
||||
@@ -134,6 +134,7 @@ def _make_masked_inputs(
|
||||
track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol
|
||||
):
|
||||
get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
|
||||
has_tokenizer: Callable[[str], bool] = lambda s: hasattr(tokenizers, s)
|
||||
|
||||
if track == "coordinates":
|
||||
dims = (sequence_length, 3, 3)
|
||||
@@ -155,12 +156,15 @@ def _make_masked_inputs(
|
||||
masked_tokens = torch.full(dims, 0.0)
|
||||
elif track == "attention_mask":
|
||||
masked_tokens = torch.full(dims, 1, dtype=torch.bool)
|
||||
else:
|
||||
elif has_tokenizer(track):
|
||||
masked_tokens = torch.full(
|
||||
dims, get_tokenizer(track).mask_token_id, dtype=torch.long
|
||||
)
|
||||
masked_tokens[0] = get_tokenizer(track).bos_token_id
|
||||
masked_tokens[-1] = get_tokenizer(track).eos_token_id
|
||||
else:
|
||||
# Does not know how to create the dummy all masked input.
|
||||
return None
|
||||
|
||||
return masked_tokens
|
||||
|
||||
@@ -173,15 +177,32 @@ def _stack_protein_tensors(
|
||||
) -> _BatchedESMProteinTensor:
|
||||
o = _BatchedESMProteinTensor()
|
||||
|
||||
def _maybe_mock_input(fn, t, l):
|
||||
if t is not None:
|
||||
return t
|
||||
|
||||
# Try create dummy masked input for this prompt.
|
||||
t = _make_masked_inputs(fn, l, tokenizers)
|
||||
if t is not None:
|
||||
t = t.to(device)
|
||||
|
||||
return t
|
||||
|
||||
def _stack_field(fn: str):
|
||||
tensors = [getattr(tokens, fn) for tokens in input_tokens]
|
||||
|
||||
# Create all mask mock inputs for any tensors that are None.
|
||||
tensors = [
|
||||
t if t is not None else _make_masked_inputs(fn, l, tokenizers).to(device)
|
||||
for t, l in zip(tensors, sequence_lengths)
|
||||
_maybe_mock_input(fn, t, l) for t, l in zip(tensors, sequence_lengths)
|
||||
]
|
||||
|
||||
# Handle any track that has all None as the input.
|
||||
# We can't meaningfully stack tensors in this case, so simply batched
|
||||
# them as None in _BatchedESMProteinTensor.
|
||||
if all([t is None for t in tensors]):
|
||||
setattr(o, fn, None)
|
||||
return
|
||||
|
||||
if fn == "coordinates":
|
||||
mask_token_id = torch.inf
|
||||
else:
|
||||
@@ -191,7 +212,8 @@ def _stack_protein_tensors(
|
||||
o,
|
||||
fn,
|
||||
stack_variable_length_tensors(
|
||||
sequences=tensors, constant_value=mask_token_id
|
||||
sequences=tensors, # type: ignore
|
||||
constant_value=mask_token_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.1.6"
|
||||
version = "3.1.7"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {file = "LICENSE.txt"}
|
||||
license = { file = "LICENSE.txt" }
|
||||
|
||||
authors = [
|
||||
{name = "EvolutionaryScale Team"}
|
||||
]
|
||||
authors = [{ name = "EvolutionaryScale Team" }]
|
||||
|
||||
maintainers = [
|
||||
{name = "Zeming Lin", email = "zeming+esm@evolutionaryscale.ai" }
|
||||
{ name = "Zeming Lin", email = "zeming+esm@evolutionaryscale.ai" },
|
||||
]
|
||||
|
||||
classifiers = [
|
||||
@@ -36,11 +34,11 @@ dependencies = [
|
||||
"pandas",
|
||||
"cloudpathlib",
|
||||
"tenacity",
|
||||
"zstd"
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "."}
|
||||
package-dir = { "" = "." }
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
Reference in New Issue
Block a user