diff --git a/README.md b/README.md
index 6c7615a..e01bf31 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
# ESM3
+
[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.
-ESM3 is a *generative* masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
+ESM3 is a _generative_ masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
@@ -13,7 +14,6 @@ Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and f
ESM3-open is available under a [non-commercial license](https://www.evolutionaryscale.ai/policies/community-license-agreement), reproduced under `LICENSE.md`.
Visit our [Discussions page](https://github.com/evolutionaryscale/esm/discussions) to get in touch, provide feedback, ask questions or share your experience with ESM3!
-
## Quickstart for ESM3-open
```
@@ -61,10 +61,12 @@ Let's explore some more advanced prompting with the help of our [notebooks and s
[
](https://colab.research.google.com/github/evolutionaryscale/esm/blob/main/examples/gfp_design.ipynb)
We also provide example scripts that show common workflows under `examples/`:
-* [local_generate.py](./examples/local_generate.py) shows how simple and elegant common tasks are: it shows folding, inverse folding and chain of thought generation, all by calling just `model.generate()` for iterative decoding.
-* [seqfun_struct.py](./examples/seqfun_struct.py) shows direct use of the model as a standard pytorch model with a simple model `forward` call.
+
+- [local_generate.py](./examples/local_generate.py) shows how simple and elegant common tasks are: it shows folding, inverse folding and chain of thought generation, all by calling just `model.generate()` for iterative decoding.
+- [seqfun_struct.py](./examples/seqfun_struct.py) shows direct use of the model as a standard pytorch model with a simple model `forward` call.
## Forge: Access to larger ESM3 models
+
You can apply for beta access to the full family of larger and higher capability ESM3 models at [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai).
We encourage users to interact with the Forge API through the python `esm` library instead of the command line.
@@ -72,14 +74,16 @@ The python interface enables you to interactively load proteins, build prompts,
with the `ESMProtein` and config classes used to interact with the local model.
In any example script try to replace a local `ESM3` model with a Forge API client:
+
```py
# Instead of loading the model locally on your machine:
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda") # or "cpu"
# just replace the line with this:
-model: ESM3InferenceClient = esm.sdk.client("esm3-md-v1", token="")
+model: ESM3InferenceClient = esm.sdk.client("esm3-medium-2024-08", token="")
# and now you're interfacing with the model running on our remote servers.
...
```
+
and the exact same code will work.
This enables a seamless transition from smaller and faster models, to our large 98B protein language models for protein design work.
@@ -96,7 +100,6 @@ The core tenets of our framework are
With this in mind, we have performed a variety of mitigations for `esm3-sm-open-v1`, detailed in our [paper](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model)
-
## License
**The Big Picture:**
@@ -113,5 +116,4 @@ With this in mind, we have performed a variety of mitigations for `esm3-sm-open-
3. You **can publish, share and adapt** the EvolutionaryScale AI Model and its outputs for **non-commercial purposes** in accordance with the Community License Agreement, including a **non-commercial restriction** on the adapted model.
-
Please read our non-commercial [Community License Agreement](https://www.evolutionaryscale.ai/policies/community-license-agreement) reproduced under [./LICENSE.md](LICENSE.md) before using ESM3.
diff --git a/esm/__init__.py b/esm/__init__.py
index 8e10cb4..e94f36f 100644
--- a/esm/__init__.py
+++ b/esm/__init__.py
@@ -1 +1 @@
-__version__ = "3.0.4"
+__version__ = "3.0.5"
diff --git a/esm/models/esm3.py b/esm/models/esm3.py
index 2ebcf41..f383dff 100644
--- a/esm/models/esm3.py
+++ b/esm/models/esm3.py
@@ -34,7 +34,10 @@ from esm.tokenization import (
)
from esm.utils import encoding
from esm.utils.constants import esm3 as C
-from esm.utils.constants.models import ESM3_OPEN_SMALL
+from esm.utils.constants.models import (
+ ESM3_OPEN_SMALL,
+ normalize_model_name,
+)
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
@@ -240,7 +243,8 @@ class ESM3(nn.Module, ESM3InferenceClient):
) -> ESM3:
from esm.pretrained import load_local_model
- if model_name not in [ESM3_OPEN_SMALL]:
+ model_name = normalize_model_name(model_name)
+ if not model_name:
raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -254,6 +258,10 @@ class ESM3(nn.Module, ESM3InferenceClient):
def device(self):
return next(self.parameters()).device
+ @property
+ def raw_model(self):
+ return self
+
def get_structure_encoder(self) -> StructureTokenEncoder:
if self._structure_encoder is None:
self._structure_encoder = self.structure_encoder_fn(self.device)
diff --git a/esm/sdk/api.py b/esm/sdk/api.py
index 4e4a0d4..2e1a0fd 100644
--- a/esm/sdk/api.py
+++ b/esm/sdk/api.py
@@ -202,6 +202,7 @@ class ESMProteinTensor(ProteinType):
@define
class ESMProteinError(Exception, ProteinType):
+ error_code: int # Error code follows HTTP convention, i.e., 404 NotFoundError, 500 InternalError.
error_msg: str
@@ -343,3 +344,8 @@ class ESM3InferenceClient(ABC):
# This is the way for power users to run ESM3. We hope to design this in a way to enable high throughput
# inference, as well as arbitrary chain-of-though invocations of ESM3.
raise NotImplementedError
+
+ @property
+ def raw_model(self):
+ # Get underlying esm3 model of an inference client.
+ raise NotImplementedError
diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py
index 820fb18..18592c5 100644
--- a/esm/sdk/forge.py
+++ b/esm/sdk/forge.py
@@ -1,9 +1,16 @@
import asyncio
+from functools import wraps
from typing import Sequence
from urllib.parse import urljoin
import requests
import torch
+from tenacity import (
+ retry,
+ retry_if_result,
+ stop_after_attempt,
+ wait_exponential,
+)
from esm.sdk.api import (
ESM3InferenceClient,
@@ -30,6 +37,20 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
return [FunctionAnnotation(*t) for t in l]
+def retry_if_specific_error(exception):
+ """
+ We only retry on specific errors.
+ Currently we retry for 502 (bad gateway) and 429 (rate limit)
+ """
+ return isinstance(exception, ESMProteinError) and exception.error_code in {429, 502}
+
+
+def log_retry_attempt(retry_state):
+ print(
+ f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
+ )
+
+
class ESM3ForgeInferenceClient(ESM3InferenceClient):
def __init__(
self,
@@ -37,24 +58,58 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
url: str = "https://forge.evolutionaryscale.ai",
token: str = "",
request_timeout: int | None = None,
+ min_retry_wait: int = 1,
+ max_retry_wait: int = 10,
+ max_retry_attempts: int = 5,
):
if token == "":
raise RuntimeError(
"Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
)
- self.model = model
+ self.model = model # Name of the model to run.
self.url = url
self.token = token
self.headers = {"Authorization": f"Bearer {self.token}"}
self.request_timeout = request_timeout
+ self.min_retry_wait = min_retry_wait
+ self.max_retry_wait = max_retry_wait
+ self.max_retry_attempts = max_retry_attempts
+ @staticmethod
+ def retry_decorator(func):
+ """
+ A static method that returns a retry decorator. This decorator uses the
+ instance's retry settings.
+ """
+
+ @wraps(func)
+ def wrapper(instance, *args, **kwargs):
+ retry_decorator = retry(
+ retry=retry_if_result(retry_if_specific_error),
+ wait=wait_exponential(
+ multiplier=1,
+ min=instance.min_retry_wait,
+ max=instance.max_retry_wait,
+ ),
+ stop=stop_after_attempt(instance.max_retry_attempts),
+ before_sleep=log_retry_attempt,
+ )
+ # Apply the retry decorator to the function
+ return retry_decorator(func)(instance, *args, **kwargs)
+
+ return wrapper
+
+ @retry_decorator
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
if isinstance(input, ESMProtein):
output = self.__generate_protein(input, config)
elif isinstance(input, ESMProteinTensor):
output = self.__generate_protein_tensor(input, config)
else:
- return ESMProteinError(error_msg=f"Unknown input type {type(input)}")
+ return ESMProteinError(
+ error_code=500,
+ error_msg=f"Unknown input type {type(input)}",
+ )
if (
isinstance(output, ESMProtein)
@@ -72,7 +127,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
return output
def batch_generate(
- self, inputs: list[ProteinType], configs: list[GenerationConfig]
+ self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
) -> Sequence[ProteinType]:
"""Forge supports auto-batching. So batch_generate() for the Forge client
is as simple as running a collection of generate() in parallel using asyncio.
@@ -88,10 +143,12 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
results = loop.run_until_complete(_async_generate())
- return [
- r if not isinstance(r, BaseException) else ESMProteinError(str(r))
- for r in results
- ]
+ def _capture_exception(r):
+ if isinstance(r, BaseException) and not isinstance(r, ESMProteinError):
+ return ESMProteinError(500, str(r))
+ return r
+
+ return [_capture_exception(r) for r in results]
def __generate_protein(
self,
@@ -119,8 +176,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
}
try:
data = self._post("generate", request, input.potential_sequence_of_concern)
- except RuntimeError as e:
- return ESMProteinError(error_msg=str(e))
+ except ESMProteinError as e:
+ return e
return ESMProtein(
sequence=data["outputs"]["sequence"],
@@ -166,8 +223,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
data = self._post(
"generate_tensor", request, input.potential_sequence_of_concern
)
- except RuntimeError as e:
- return ESMProteinError(error_msg=str(e))
+ except ESMProteinError as e:
+ return e
def _field_to_tensor(field, convert_none_to_nan: bool = False):
if field not in data["outputs"]:
@@ -188,6 +245,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
return output
+ @retry_decorator
def forward_and_sample(
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput | ESMProteinError:
@@ -237,8 +295,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
data = self._post(
"forward_and_sample", request, input.potential_sequence_of_concern
)
- except RuntimeError as e:
- return ESMProteinError(error_msg=str(e))
+ except ESMProteinError as e:
+ return e
def get(k, field):
if data[k] is None:
@@ -286,6 +344,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
)
return output
+ @retry_decorator
def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
tracks = {}
tracks["sequence"] = input.sequence
@@ -299,8 +358,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
try:
data = self._post("encode", request, input.potential_sequence_of_concern)
- except RuntimeError as e:
- return ESMProteinError(error_msg=str(e))
+ except ESMProteinError as e:
+ return e
return ESMProteinTensor(
sequence=maybe_tensor(data["outputs"]["sequence"]),
@@ -314,6 +373,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
)
+ @retry_decorator
def decode(
self,
input: ESMProteinTensor,
@@ -334,8 +394,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
try:
data = self._post("decode", request, input.potential_sequence_of_concern)
- except RuntimeError as e:
- return ESMProteinError(error_msg=str(e))
+ except ESMProteinError as e:
+ return e
return ESMProtein(
sequence=data["outputs"]["sequence"],
@@ -351,9 +411,10 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
ptm=maybe_tensor(data["outputs"]["ptm"]),
)
+ @retry_decorator
def logits(
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
- ) -> LogitsOutput:
+ ) -> LogitsOutput | ESMProteinError:
# Note: using raw model forwards is discouraged because of the byte size
# of the logits.
# Please use forward_and_sample instead.
@@ -381,7 +442,11 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
"inputs": req,
"logits_config": logits_config,
}
- data = self._post("logits", request, input.potential_sequence_of_concern)
+
+ try:
+ data = self._post("logits", request, input.potential_sequence_of_concern)
+ except ESMProteinError as e:
+ return e
def _maybe_logits(track: str):
if "logits" in data and track in data["logits"]:
@@ -413,7 +478,10 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
)
if not response.ok:
- raise RuntimeError(f"Failure in {endpoint}: {response.text}")
+ raise ESMProteinError(
+ error_code=response.status_code,
+ error_msg=f"Failure in {endpoint}: {response.text}",
+ )
data = response.json()
# Nextjs puts outputs dict under "data" key.
@@ -422,3 +490,9 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
data = data["data"]
return data
+
+ @property
+ def raw_model(self):
+ raise NotImplementedError(
+ f"Can not get underlying remote model {self.model} from a Forge client."
+ )
diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py
index d22c0de..70b9885 100644
--- a/esm/tokenization/__init__.py
+++ b/esm/tokenization/__init__.py
@@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import Protocol
-from esm.utils.constants.models import ESM3_OPEN_SMALL
+from esm.utils.constants.models import (
+ ESM3_OPEN_SMALL,
+ normalize_model_name,
+)
from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer
@@ -32,7 +35,7 @@ class TokenizerCollection:
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
- if model == ESM3_OPEN_SMALL:
+ if normalize_model_name(model) == ESM3_OPEN_SMALL:
return TokenizerCollection(
sequence=EsmSequenceTokenizer(),
structure=StructureTokenizer(),
diff --git a/esm/utils/constants/models.py b/esm/utils/constants/models.py
index c72b922..3ffab24 100644
--- a/esm/utils/constants/models.py
+++ b/esm/utils/constants/models.py
@@ -1,5 +1,23 @@
# Model names
ESM3_OPEN_SMALL = "esm3_sm_open_v1"
+ESM3_OPEN_SMALL_ALIAS_1 = "esm3-small-open-2024-03"
+ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1"
+ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open"
ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0"
ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0"
ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0"
+
+
+def model_is_locally_supported(x: str):
+ return x in {
+ ESM3_OPEN_SMALL,
+ ESM3_OPEN_SMALL_ALIAS_1,
+ ESM3_OPEN_SMALL_ALIAS_2,
+ ESM3_OPEN_SMALL_ALIAS_3,
+ }
+
+
+def normalize_model_name(x: str):
+ if x in {ESM3_OPEN_SMALL_ALIAS_1, ESM3_OPEN_SMALL_ALIAS_2, ESM3_OPEN_SMALL_ALIAS_3}:
+ return ESM3_OPEN_SMALL
+ return x
diff --git a/esm/utils/generation.py b/esm/utils/generation.py
index d53d805..ec64fe0 100644
--- a/esm/utils/generation.py
+++ b/esm/utils/generation.py
@@ -299,6 +299,23 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
return where_to_sample
+def _get_non_special_tokens(
+ protein: ESMProteinTensor, tokenizers: TokenizerCollectionProtocol
+) -> int:
+ if protein.sequence is None:
+ # There is no sequence to infer the number of tokens to decode.
+ # So we assume the entire sequence minus bos and eos are for decoding.
+ return len(protein) - 2
+
+ mask = torch.ones_like(protein.sequence)
+ for special_token in tokenizers.sequence.special_token_ids:
+ if special_token == tokenizers.sequence.mask_token_id:
+ continue # MASK tokens need to be sampled.
+ mask[protein.sequence == special_token] = 0
+
+ return int(torch.sum(mask).item())
+
+
def iterative_sampling_tokens(
client: ESM3InferenceClient,
input_tokens: list[ESMProteinTensor],
@@ -320,12 +337,12 @@ def iterative_sampling_tokens(
sequence_lengths = [len(tokens) for tokens in sampled_tokens]
# Figure out the number of tokens to be sampled for each prompt.
total_to_sample = []
- for protein, seq_len, config in zip(sampled_tokens, sequence_lengths, configs):
+ for protein, config in zip(sampled_tokens, configs):
track = config.track
if getattr(protein, track) is None:
# We need to sample the entire track.
- total_to_sample.append(seq_len - 2)
+ total_to_sample.append(_get_non_special_tokens(protein, tokenizers))
continue
masked = _get_masked_positions(
@@ -368,7 +385,8 @@ def iterative_sampling_tokens(
if config.track in ["coordinates", "residue_annotations"]:
errors[i] = ESMProteinError(
- error_msg=f"Iterative sampling {config.track} is not supported."
+ error_code=500,
+ error_msg=f"Iterative sampling {config.track} is not supported.",
)
continue
@@ -427,7 +445,7 @@ def iterative_sampling_tokens(
tokenizers,
)
except ValueError as e:
- errors[i] = ESMProteinError(error_msg=str(e))
+ errors[i] = ESMProteinError(error_code=500, error_msg=str(e))
continue
where_to_sample.to(input_tokens[0].device)
@@ -547,7 +565,7 @@ def _sample_per_prompt(
valid_ids = (
set(tokenizers.sasa.all_token_ids)
- - set(tokenizer.special_token_ids)
+ - set(tokenizers.sasa.special_token_ids)
- set(config.invalid_ids)
)
sasa_logits = logits_output.logits.sasa
@@ -568,7 +586,8 @@ def _sample_per_prompt(
# Sample function and residue annotations separately
config = getattr(sampling_config, "function")
- if config is None:
+ function_logits = getattr(logits_output.logits, "function")
+ if config is None or function_logits is None:
tokens_dir["function"] = maybe_clone(getattr(protein, "function"))
tokens_dir["residue_annotations"] = maybe_clone(
getattr(protein, "residue_annotations")
@@ -580,7 +599,7 @@ def _sample_per_prompt(
sampling_metadata = _sample_function_track(
tokenizers.function,
tokens=getattr(protein, "function"),
- logits=getattr(logits_output.logits, "function"),
+ logits=function_logits,
sampling_track_config=config,
)
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
diff --git a/esm/utils/generation_test.py b/esm/utils/generation_test.py
new file mode 100644
index 0000000..3435ed5
--- /dev/null
+++ b/esm/utils/generation_test.py
@@ -0,0 +1,116 @@
+import pytest
+import torch
+
+from esm.sdk.api import (
+ ESMProtein,
+ ESMProteinError,
+ ESMProteinTensor,
+ GenerationConfig,
+)
+from evolutionaryscale.utils.env import ModelName
+from evolutionaryscale.utils.remote_inference.api_v1 import (
+ ESM3RemoteModelInferenceClient,
+)
+from projects.forge.fastapi.utils.model import _load_esm3
+
+
+@pytest.fixture()
+def esm3_remote_inference_client():
+ model = _load_esm3(ModelName.ESM3_TINY_DEV, distributed_model=False)
+ client = ESM3RemoteModelInferenceClient(
+ model,
+ tokenizers=model.tokenizers,
+ device=torch.device("cuda"),
+ enable_batched_runner=False,
+ )
+ return client
+
+
+@pytest.mark.gpu
+def test_chain_break_tokens(esm3_remote_inference_client):
+ tokenizer = esm3_remote_inference_client.tokenizers.sequence
+ # 3 separate chains with 2 chainbreak tokens.
+ sequence_with_chain_breaks = torch.tensor(
+ [
+ tokenizer.bos_token_id,
+ 20,
+ 20,
+ 20,
+ 20,
+ tokenizer.chain_break_token_id,
+ 21,
+ 21,
+ 21,
+ tokenizer.chain_break_token_id,
+ 22,
+ 22,
+ 22,
+ tokenizer.eos_token_id,
+ ]
+ )
+ protein = esm3_remote_inference_client.generate(
+ ESMProteinTensor(sequence=sequence_with_chain_breaks),
+ # There are 10 tokens that actually need to be sampled.
+ GenerationConfig(track="structure", num_steps=10),
+ )
+
+ assert isinstance(protein, ESMProteinTensor)
+ assert protein.structure is not None
+
+
+@pytest.mark.gpu
+def test_num_decoding_steps_more_than_mask_tokens_fails(esm3_remote_inference_client):
+ protein = esm3_remote_inference_client.generate(
+ ESMProtein(sequence="CDEFG"), # sequence of 5.
+ GenerationConfig(track="structure", num_steps=10), # use 10 decoding steps.
+ )
+ # Can't specify more decoding steps than masks available.
+ assert isinstance(protein, ESMProteinError)
+ assert protein.error_code == 500
+
+
+@pytest.mark.gpu
+def test_encode_chainbreak_token(esm3_remote_inference_client):
+ protein = esm3_remote_inference_client.encode(
+ ESMProtein(sequence="MSTNP|KPQKK"),
+ )
+ # Can't specify more decoding steps than masks available.
+ assert isinstance(protein, ESMProteinTensor)
+ assert protein.sequence is not None
+ assert (
+ protein.sequence[6]
+ == esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id
+ )
+
+
+@pytest.mark.gpu
+def test_generation_with_chainbreak_token(esm3_remote_inference_client):
+ chainbreak_sequence = torch.tensor(
+ [
+ esm3_remote_inference_client.tokenizers.sequence.bos_token_id,
+ 20,
+ 8,
+ 11,
+ 17,
+ 14,
+ esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id,
+ 15,
+ 14,
+ 16,
+ 15,
+ 15,
+ esm3_remote_inference_client.tokenizers.sequence.eos_token_id,
+ ]
+ )
+
+ protein = esm3_remote_inference_client.generate(
+ ESMProteinTensor(sequence=chainbreak_sequence),
+ GenerationConfig(track="structure", num_steps=1),
+ )
+ # Can't specify more decoding steps than masks available.
+ assert isinstance(protein, ESMProteinTensor)
+ assert protein.structure is not None
+ assert (
+ protein.structure[6]
+ == esm3_remote_inference_client.tokenizers.structure.chain_break_token_id
+ )
diff --git a/esm/utils/sampling.py b/esm/utils/sampling.py
index 3db37fd..097e418 100644
--- a/esm/utils/sampling.py
+++ b/esm/utils/sampling.py
@@ -163,6 +163,10 @@ def sample_logits(
logits is shape (..., vocab_size)
temperature is broadcastable to (...)
"""
+ if len(valid_ids) == 0:
+ raise ValueError(
+ "Can not sample logits if there are no valid ids to sample from."
+ )
if top_p < 1.0:
logits = top_p_logits(logits, top_p=top_p)
@@ -181,7 +185,7 @@ def sample_logits(
if torch.all(temperature == 0):
ids = logits.argmax(-1)
- return ids
+ return ids.reshape(*batch_dims)
assert not torch.any(temperature == 0), "Partial temperature 0 not supported."
diff --git a/esm/utils/sampling_test.py b/esm/utils/sampling_test.py
new file mode 100644
index 0000000..5abfa4e
--- /dev/null
+++ b/esm/utils/sampling_test.py
@@ -0,0 +1,40 @@
+import pytest
+import torch
+
+from esm.utils.sampling import sample_logits
+
+
+def test_sample_logits():
+ # batched input. temperature != 0.0.
+ sampled = sample_logits(
+ logits=torch.randn((64, 8, 4096)), temperature=0.8, valid_ids=list(range(4096))
+ )
+ assert sampled.shape == (64, 8)
+
+ # batched input. temperature == 0.0.
+ sampled = sample_logits(
+ logits=torch.randn((64, 8, 4096)), temperature=0.0, valid_ids=list(range(4096))
+ )
+ assert sampled.shape == (64, 8)
+
+ # non-batched input. temperature != 0.0.
+ sampled = sample_logits(
+ logits=torch.randn((8, 4096)), temperature=0.8, valid_ids=list(range(4096))
+ )
+ assert sampled.shape == (8,)
+
+ # non-batched input. temperature == 0.0.
+ sampled = sample_logits(
+ logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=list(range(4096))
+ )
+ assert sampled.shape == (8,)
+
+ with pytest.raises(ValueError):
+ sampled = sample_logits(
+ logits=torch.randn((8, 4096)),
+ temperature=0.0,
+ valid_ids=[],
+ )
+
+
+test_sample_logits()
diff --git a/esm/widgets/views/esm3_generation_launcher.py b/esm/widgets/views/esm3_generation_launcher.py
index f6a8feb..6328446 100644
--- a/esm/widgets/views/esm3_generation_launcher.py
+++ b/esm/widgets/views/esm3_generation_launcher.py
@@ -157,7 +157,10 @@ def create_esm3_generation_launcher(
model=model_name.value, token=forge_token
)
elif isinstance(client, ESM3):
- if model_name.value != models.ESM3_OPEN_SMALL:
+ if (
+ models.normalize_model_name(model_name.value)
+ != models.ESM3_OPEN_SMALL
+ ):
raise ValueError(
f"Model name {model_name.value} does not match the client model {models.ESM3_OPEN_SMALL}"
)
diff --git a/esm/widgets/views/login.py b/esm/widgets/views/login.py
index d5cae47..0de4500 100644
--- a/esm/widgets/views/login.py
+++ b/esm/widgets/views/login.py
@@ -92,7 +92,7 @@ def create_login_ui(client_container: ClientInitContainer):
layout={"width": "50%"},
)
forge_model = widgets.Text(
- value="esm3-md-v1",
+ value="esm3-medium-2024-08",
description="Model Name:",
disabled=False,
layout={"width": "50%"},
diff --git a/examples/generate.ipynb b/examples/generate.ipynb
index 357a5bf..ef55ff6 100644
--- a/examples/generate.ipynb
+++ b/examples/generate.ipynb
@@ -1,582 +1,582 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ESM3\n",
- "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n",
- "\n",
- "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n",
- "\n",
- "\n",
- "\n",
- "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n",
- "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Imports\n",
- "\n",
- "If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%set_env TOKENIZERS_PARALLELISM=false\n",
- "!pip install esm\n",
- "import numpy as np\n",
- "import torch\n",
- "!pip install py3Dmol\n",
- "import py3Dmol\n",
- "import huggingface_hub\n",
- "\n",
- "from esm.utils.structure.protein_chain import ProteinChain\n",
- "from esm.models.esm3 import ESM3\n",
- "from esm.sdk import client\n",
- "from esm.sdk.api import (\n",
- " ESMProtein,\n",
- " GenerationConfig,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Load `esm-open-small` on GPU"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "huggingface_hub.login() # will prompt you to get an API key and accept the ESM3 license.\n",
- "model = ESM3.from_pretrained(\"esm3_sm_open_v1\", device=torch.device(\"cuda\"))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Alternatively, you could use the Forge API running the model remotely, and use the local `client` to call the API just like you're used to with the model running locally on your GPU:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# from getpass import getpass\n",
- "# token = getpass(\"Token from Forge console: \")\n",
- "# model = client(\n",
- "# model=\"esm3-lg-alpha1\",\n",
- "# url=\"https://forge.evolutionaryscale.ai\",\n",
- "# token=token,\n",
- "# )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "First, we can use the `ProteinChain` class from the `esm` sdk to grab a protein structure from the PDB.\n",
- "We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pdb_id = \"1ITU\" # PDB ID corresponding to Renal Dipeptidase\n",
- "chain_id = \"A\" # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n",
- "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n",
- "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The `ProteinChain` class is a object that makes it easy to work with protein structures. It contains a `sequence` attribute that contains the amino acid sequence of the protein\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(renal_dipep_chain.sequence)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein. \n",
- "\n",
- "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n",
- "print(renal_dipep_chain.atom37_positions[:3])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can visualize the protein chain using the `py3Dmol` library"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# First we can create a `py3Dmol` view object\n",
- "view = py3Dmol.view(width=500, height=500)\n",
- "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n",
- "pdb_str = renal_dipep_chain.to_pdb_string()\n",
- "# Load the PDB string into the `py3Dmol` view object\n",
- "view.addModel(pdb_str, \"pdb\")\n",
- "# Set the style of the protein chain\n",
- "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
- "# Zoom in on the protein chain\n",
- "view.zoomTo()\n",
- "# Display the protein chain\n",
- "view.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "motif_inds = np.arange(123, 146)\n",
- "# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n",
- "motif_sequence = renal_dipep_chain[motif_inds].sequence\n",
- "motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n",
- "print(\"Motif sequence: \", motif_sequence)\n",
- "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in grey and the motif in blue"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "view = py3Dmol.view(width=500, height=500)\n",
- "view.addModel(pdb_str, \"pdb\")\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
- "motif_res_inds = (motif_inds + 1).tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n",
- "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n",
- "view.zoomTo()\n",
- "view.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "prompt_length = 200\n",
- "# First, we can construct a sequence prompt of all masks\n",
- "sequence_prompt = [\"_\"]*prompt_length\n",
- "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n",
- "sequence_prompt[72:72+len(motif_sequence)] = list(motif_sequence)\n",
- "sequence_prompt = \"\".join(sequence_prompt)\n",
- "print(\"Sequence prompt: \", sequence_prompt)\n",
- "print(\"Length of sequence prompt: \", len(sequence_prompt))\n",
- "\n",
- "# Next, we can construct a structure prompt of all nan coordinates\n",
- "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n",
- "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n",
- "structure_prompt[72:72+len(motif_atom37_positions)] = torch.tensor(motif_atom37_positions)\n",
- "print(\"Structure prompt shape: \", structure_prompt.shape)\n",
- "print(\"Indices with structure conditioning: \", torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist())\n",
- "\n",
- "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n",
- "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n",
- "sequence_generation_config = GenerationConfig(\n",
- " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n",
- " num_steps=sequence_prompt.count(\"_\") // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n",
- " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n",
- ")\n",
- "\n",
- "# Now, we can use the `generate` method of the model to decode the sequence\n",
- "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n",
- "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n",
- "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "structure_prediction_config = GenerationConfig(\n",
- " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n",
- " num_steps=len(sequence_generation) // 8,\n",
- " temperature=0.7, \n",
- ")\n",
- "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n",
- "structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. The motif residues are colored in cyan."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Convert the generated structure to a back into a ProteinChain object\n",
- "structure_prediction_chain = structure_prediction.to_protein_chain()\n",
- "# Align the generated structure to the original structure using the motif residues\n",
- "motif_inds_in_generation = np.arange(72, 72+len(motif_sequence))\n",
- "structure_prediction_chain.align(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)\n",
- "crmsd = structure_prediction_chain.rmsd(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)\n",
- "print(\"cRMSD of the motif in the generated structure vs the original structure: \", crmsd)\n",
- "\n",
- "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
- "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n",
- "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n",
- "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n",
- "view.addStyle({\"resi\": (motif_inds_in_generation+1).tolist()}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 1))\n",
- "view.zoomTo()\n",
- "view.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Secondary Structure Editing Example: Helix Shortening"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n",
- "view = py3Dmol.view(width=500, height=500)\n",
- "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
- "helix_region = np.arange(38, 111) # zero-indexed\n",
- "view.addStyle({\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\":\"lightblue\"}})\n",
- "view.zoomTo()\n",
- "view.show()\n",
- "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n",
- "print(\"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\", helix_shortening_ss8)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The helix-coil-helix region in the original protein is 73 residues long. We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "shortened_region_length = 45\n",
- "\n",
- "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n",
- "sequence_prompt = helix_shortening_chain.sequence[:helix_region[0]] + \"_\" * shortened_region_length + helix_shortening_chain.sequence[helix_region[-1] + 1:]\n",
- "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n",
- "\n",
- "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n",
- "ss8_prompt = helix_shortening_ss8[:helix_region[0]] + (((shortened_region_length - 3) // 2) * \"H\" + \"C\"*3 + ((shortened_region_length - 3) // 2) * \"H\") + helix_shortening_ss8[helix_region[-1] + 1:]\n",
- "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n",
- "print(\"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\", \" \"*helix_region[0] + ss8_prompt[helix_region[0]:helix_region[0]+45])\n",
- "\n",
- "print(\"\")\n",
- "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n",
- "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n",
- "print(\"Original SS8 for helix-coil-helix region:\\n\\t\", \" \"*helix_region[0] + helix_shortening_ss8[helix_region[0]:helix_region[-1]+1])\n",
- "\n",
- "\n",
- "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n",
- "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "print(\"Generating protein sequence...\")\n",
- "sequence_generation = model.generate(protein_prompt, GenerationConfig(track=\"sequence\", num_steps=protein_prompt.sequence.count(\"_\") // 2, temperature=0.5))\n",
- "print(\"Folding protein...\")\n",
- "structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "predicted_chain = structure_prediction.to_protein_chain()\n",
- "predicted_chain = predicted_chain.align(helix_shortening_chain, mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)), target_inds=np.arange(len(helix_shortening_chain) - 120, len(helix_shortening_chain)))\n",
- "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
- "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
- "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
- "view.addStyle({\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\":\"lightblue\"}},viewer=(0, 0))\n",
- "view.addStyle({\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()}, {\"cartoon\": {\"color\":\"pink\"}},viewer=(0, 1))\n",
- "view.zoomTo()\n",
- "view.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# SASA Editing Example: Exposing a buried helix"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n",
- "span_start = 105\n",
- "span_end = 116\n",
- "view = py3Dmol.view(width=500, height=500)\n",
- "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
- "view.addStyle({\"resi\": (np.arange(span_start, span_end) + 1).tolist()}, {\"cartoon\": {\"color\":\"red\"}})\n",
- "view.zoomTo()\n",
- "view.show()\n",
- "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n",
- "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n",
- "2. Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n",
- "structure_prompt[span_start:span_end] = torch.tensor(lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32) \n",
- "\n",
- "sasa_prompt = [None]*len(lipase_chain)\n",
- "sasa_prompt[span_start:span_end] = [40.0]*(span_end - span_start)\n",
- "\n",
- "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n",
- "\n",
- "protein_prompt = ESMProtein(sequence=\"_\"*len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "This is a more difficult task, so you may need to sample more generations from ESM before you find a solution. We'll sample 32 here and sort by the generations with the highest predicted TM-score (pTM) by ESM3. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "generated_proteins = []\n",
- "N_SAMPLES = 16\n",
- "for i in range(N_SAMPLES):\n",
- " print(\"Generating protein sequence...\")\n",
- " sequence_generation = model.generate(protein_prompt, GenerationConfig(track=\"sequence\", num_steps=len(protein_prompt) // 8, temperature=0.7))\n",
- " print(\"Folding protein...\")\n",
- " structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 32))\n",
- " generated_proteins.append(structure_prediction)\n",
- "\n",
- "# Sort generations by ptm\n",
- "generated_proteins = sorted(generated_proteins, key=lambda x: x.ptm.item(), reverse=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let's visualize the top 4 generations by pTM, alongside with the original protein (on the left)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "N_SAMPLES_TO_SHOW = 4\n",
- "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW+1))\n",
- "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
- "for i in range(N_SAMPLES_TO_SHOW):\n",
- " print(\"PTM of generated protein {}: {:.2f}\".format(i+1, generated_proteins[i].ptm.item()))\n",
- " view.addModel(generated_proteins[i].to_protein_chain().to_pdb_string(), \"pdb\", viewer=(0, i+1))\n",
- "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
- "view.addStyle({\"resi\": (np.arange(span_start, span_end) + 1).tolist()}, {\"cartoon\": {\"color\": \"red\"}})\n",
- "view.zoomTo()\n",
- "view.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ESM3\n",
+ "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n",
+ "\n",
+ "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.\n",
+ "\n",
+ "\n",
+ "\n",
+ "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n",
+ "Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Imports\n",
+ "\n",
+ "If you're running in Colab, you probably want to get a GPU runtime first (Runtime > Change runtime type > T4 GPU)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%set_env TOKENIZERS_PARALLELISM=false\n",
+ "!pip install esm\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "!pip install py3Dmol\n",
+ "import py3Dmol\n",
+ "import huggingface_hub\n",
+ "\n",
+ "from esm.utils.structure.protein_chain import ProteinChain\n",
+ "from esm.models.esm3 import ESM3\n",
+ "from esm.sdk import client\n",
+ "from esm.sdk.api import (\n",
+ " ESMProtein,\n",
+ " GenerationConfig,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Load `esm-open-small` on GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "huggingface_hub.login() # will prompt you to get an API key and accept the ESM3 license.\n",
+ "model = ESM3.from_pretrained(\"esm3_sm_open_v1\", device=torch.device(\"cuda\"))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Alternatively, you could use the Forge API running the model remotely, and use the local `client` to call the API just like you're used to with the model running locally on your GPU:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# from getpass import getpass\n",
+ "# token = getpass(\"Token from Forge console: \")\n",
+ "# model = client(\n",
+ "# model=\"esm3-large-2024-03\",\n",
+ "# url=\"https://forge.evolutionaryscale.ai\",\n",
+ "# token=token,\n",
+ "# )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, we can use the `ProteinChain` class from the `esm` sdk to grab a protein structure from the PDB.\n",
+ "We'll work with a human renal (kidney) dipeptidase (a protein that breaks up two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pdb_id = \"1ITU\" # PDB ID corresponding to Renal Dipeptidase\n",
+ "chain_id = \"A\" # Chain ID corresponding to Renal Dipeptidase in the PDB structure\n",
+ "renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n",
+ "# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `ProteinChain` class is a object that makes it easy to work with protein structures. It contains a `sequence` attribute that contains the amino acid sequence of the protein\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(renal_dipep_chain.sequence)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`ProteinChain` also contains an `atom37_positions` numpy array that contains the atomic coordinates of each of the residues in the protein. \n",
+ "\n",
+ "The shape of the array is `(n_residues, 37, 3)` where `n_residues` is the number of residues in the protein and 37 is the number of possible distinct atoms that may be present across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms corresponding to the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation of protein structure allows us to use a single format to conveniently represent all amino acids -- **coordinates are only present for the atoms that are present in the amino acid and `nan` otherwise**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"atom37_positions shape: \", renal_dipep_chain.atom37_positions.shape)\n",
+ "print(renal_dipep_chain.atom37_positions[:3])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can visualize the protein chain using the `py3Dmol` library"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# First we can create a `py3Dmol` view object\n",
+ "view = py3Dmol.view(width=500, height=500)\n",
+ "# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string\n",
+ "pdb_str = renal_dipep_chain.to_pdb_string()\n",
+ "# Load the PDB string into the `py3Dmol` view object\n",
+ "view.addModel(pdb_str, \"pdb\")\n",
+ "# Set the style of the protein chain\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
+ "# Zoom in on the protein chain\n",
+ "view.zoomTo()\n",
+ "# Display the protein chain\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "motif_inds = np.arange(123, 146)\n",
+ "# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues\n",
+ "motif_sequence = renal_dipep_chain[motif_inds].sequence\n",
+ "motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions\n",
+ "print(\"Motif sequence: \", motif_sequence)\n",
+ "print(\"Motif atom37_positions shape: \", motif_atom37_positions.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also visualize the motif in the original chain using `py3Dmol`. We'll color the original chain in grey and the motif in blue"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "view = py3Dmol.view(width=500, height=500)\n",
+ "view.addModel(pdb_str, \"pdb\")\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "motif_res_inds = (motif_inds + 1).tolist() # residue indices are 1-indexed in PDB files, so we add 1 to the indices\n",
+ "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}})\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can use the `ESMProtein` class to construct a prompt that will instruct ESM3 to scaffold the motif"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt_length = 200\n",
+ "# First, we can construct a sequence prompt of all masks\n",
+ "sequence_prompt = [\"_\"]*prompt_length\n",
+ "# Then, we can randomly insert the motif sequence into the prompt (we randomly choose 72 here)\n",
+ "sequence_prompt[72:72+len(motif_sequence)] = list(motif_sequence)\n",
+ "sequence_prompt = \"\".join(sequence_prompt)\n",
+ "print(\"Sequence prompt: \", sequence_prompt)\n",
+ "print(\"Length of sequence prompt: \", len(sequence_prompt))\n",
+ "\n",
+ "# Next, we can construct a structure prompt of all nan coordinates\n",
+ "structure_prompt = torch.full((prompt_length, 37, 3), np.nan)\n",
+ "# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72\n",
+ "structure_prompt[72:72+len(motif_atom37_positions)] = torch.tensor(motif_atom37_positions)\n",
+ "print(\"Structure prompt shape: \", structure_prompt.shape)\n",
+ "print(\"Indices with structure conditioning: \", torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist())\n",
+ "\n",
+ "# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3\n",
+ "protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can use the `generate` method of the model to iteratively sample a protein sequence based on the prompt. Under the hood, the model performs num_steps forward passes and samples a set of tokens at each step until the chosen track being generated is fully unmasked. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use\n",
+ "sequence_generation_config = GenerationConfig(\n",
+ " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n",
+ " num_steps=sequence_prompt.count(\"_\") // 2, # We'll use num(mask tokens) // 2 steps to decode the sequence\n",
+ " temperature=0.5, # We'll use a temperature of 0.5 to control the randomness of the decoding process\n",
+ ")\n",
+ "\n",
+ "# Now, we can use the `generate` method of the model to decode the sequence\n",
+ "sequence_generation = model.generate(protein_prompt, sequence_generation_config)\n",
+ "print(\"Sequence Prompt:\\n\\t\", protein_prompt.sequence)\n",
+ "print(\"Generated sequence:\\n\\t\", sequence_generation.sequence)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also use the `generate` method to predict the structure of the generated sequence by iteratively sampling structure tokens."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "structure_prediction_config = GenerationConfig(\n",
+ " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n",
+ " num_steps=len(sequence_generation) // 8,\n",
+ " temperature=0.7, \n",
+ ")\n",
+ "structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)\n",
+ "structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. The motif residues are colored in cyan."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Convert the generated structure to a back into a ProteinChain object\n",
+ "structure_prediction_chain = structure_prediction.to_protein_chain()\n",
+ "# Align the generated structure to the original structure using the motif residues\n",
+ "motif_inds_in_generation = np.arange(72, 72+len(motif_sequence))\n",
+ "structure_prediction_chain.align(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)\n",
+ "crmsd = structure_prediction_chain.rmsd(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)\n",
+ "print(\"cRMSD of the motif in the generated structure vs the original structure: \", crmsd)\n",
+ "\n",
+ "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
+ "view.addModel(pdb_str, \"pdb\", viewer=(0, 0))\n",
+ "view.addModel(structure_prediction_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}}, viewer=(0, 0))\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgreen\"}}, viewer=(0, 1))\n",
+ "view.addStyle({\"resi\": motif_res_inds}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 0))\n",
+ "view.addStyle({\"resi\": (motif_inds_in_generation+1).tolist()}, {\"cartoon\": {\"color\": \"cyan\"}}, viewer=(0, 1))\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Secondary Structure Editing Example: Helix Shortening"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "helix_shortening_chain = ProteinChain.from_rcsb(\"7XBQ\", \"A\")\n",
+ "view = py3Dmol.view(width=500, height=500)\n",
+ "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\")\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "helix_region = np.arange(38, 111) # zero-indexed\n",
+ "view.addStyle({\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\":\"lightblue\"}})\n",
+ "view.zoomTo()\n",
+ "view.show()\n",
+ "helix_shortening_ss8 = \"CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC\"\n",
+ "print(\"Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \\n\\t\", helix_shortening_ss8)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The helix-coil-helix region in the original protein is 73 residues long. We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shortened_region_length = 45\n",
+ "\n",
+ "# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked\n",
+ "sequence_prompt = helix_shortening_chain.sequence[:helix_region[0]] + \"_\" * shortened_region_length + helix_shortening_chain.sequence[helix_region[-1] + 1:]\n",
+ "print(\"Sequence prompt:\\n\\t\", sequence_prompt)\n",
+ "\n",
+ "# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region\n",
+ "ss8_prompt = helix_shortening_ss8[:helix_region[0]] + (((shortened_region_length - 3) // 2) * \"H\" + \"C\"*3 + ((shortened_region_length - 3) // 2) * \"H\") + helix_shortening_ss8[helix_region[-1] + 1:]\n",
+ "print(\"SS8 prompt:\\n\\t\", ss8_prompt)\n",
+ "print(\"Proposed SS8 for shortened helix-coil-helix region:\\n\\t\", \" \"*helix_region[0] + ss8_prompt[helix_region[0]:helix_region[0]+45])\n",
+ "\n",
+ "print(\"\")\n",
+ "print(\"Original sequence:\\n\\t\", helix_shortening_chain.sequence)\n",
+ "print(\"Original SS8:\\n\\t\", helix_shortening_ss8)\n",
+ "print(\"Original SS8 for helix-coil-helix region:\\n\\t\", \" \"*helix_region[0] + helix_shortening_ss8[helix_region[0]:helix_region[-1]+1])\n",
+ "\n",
+ "\n",
+ "# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3\n",
+ "protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can again use the `generate` method of the model to iteratively decode a protein sequence based on the prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Generating protein sequence...\")\n",
+ "sequence_generation = model.generate(protein_prompt, GenerationConfig(track=\"sequence\", num_steps=protein_prompt.sequence.count(\"_\") // 2, temperature=0.5))\n",
+ "print(\"Folding protein...\")\n",
+ "structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 4, temperature=0))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "predicted_chain = structure_prediction.to_protein_chain()\n",
+ "predicted_chain = predicted_chain.align(helix_shortening_chain, mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)), target_inds=np.arange(len(helix_shortening_chain) - 120, len(helix_shortening_chain)))\n",
+ "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))\n",
+ "view.addModel(helix_shortening_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
+ "view.addModel(predicted_chain.to_pdb_string(), \"pdb\", viewer=(0, 1))\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "view.addStyle({\"resi\": (helix_region + 1).tolist()}, {\"cartoon\": {\"color\":\"lightblue\"}},viewer=(0, 0))\n",
+ "view.addStyle({\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()}, {\"cartoon\": {\"color\":\"pink\"}},viewer=(0, 1))\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# SASA Editing Example: Exposing a buried helix"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's grab 1LBS from the PDB and visualize it using `py3Dmol`. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lipase_chain = ProteinChain.from_rcsb(\"1LBS\", \"A\")\n",
+ "span_start = 105\n",
+ "span_end = 116\n",
+ "view = py3Dmol.view(width=500, height=500)\n",
+ "view.addModel(lipase_chain.to_pdb_string(), \"pdb\")\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "view.addStyle({\"resi\": (np.arange(span_start, span_end) + 1).tolist()}, {\"cartoon\": {\"color\":\"red\"}})\n",
+ "view.zoomTo()\n",
+ "view.show()\n",
+ "lipase_ss8 = \"CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:\n",
+ "1. Prompt with the **structure** of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix\n",
+ "2. Prompt with high **SASA** values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)\n",
+ "structure_prompt[span_start:span_end] = torch.tensor(lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32) \n",
+ "\n",
+ "sasa_prompt = [None]*len(lipase_chain)\n",
+ "sasa_prompt[span_start:span_end] = [40.0]*(span_end - span_start)\n",
+ "\n",
+ "print(\"SASA prompt (just for buried region): \", sasa_prompt[span_start:span_end])\n",
+ "\n",
+ "protein_prompt = ESMProtein(sequence=\"_\"*len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is a more difficult task, so you may need to sample more generations from ESM before you find a solution. We'll sample 32 here and sort by the generations with the highest predicted TM-score (pTM) by ESM3. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generated_proteins = []\n",
+ "N_SAMPLES = 16\n",
+ "for i in range(N_SAMPLES):\n",
+ " print(\"Generating protein sequence...\")\n",
+ " sequence_generation = model.generate(protein_prompt, GenerationConfig(track=\"sequence\", num_steps=len(protein_prompt) // 8, temperature=0.7))\n",
+ " print(\"Folding protein...\")\n",
+ " structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track=\"structure\", num_steps=len(protein_prompt) // 32))\n",
+ " generated_proteins.append(structure_prediction)\n",
+ "\n",
+ "# Sort generations by ptm\n",
+ "generated_proteins = sorted(generated_proteins, key=lambda x: x.ptm.item(), reverse=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's visualize the top 4 generations by pTM, alongside with the original protein (on the left)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_SAMPLES_TO_SHOW = 4\n",
+ "view = py3Dmol.view(width=1000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW+1))\n",
+ "view.addModel(lipase_chain.to_pdb_string(), \"pdb\", viewer=(0, 0))\n",
+ "for i in range(N_SAMPLES_TO_SHOW):\n",
+ " print(\"PTM of generated protein {}: {:.2f}\".format(i+1, generated_proteins[i].ptm.item()))\n",
+ " view.addModel(generated_proteins[i].to_protein_chain().to_pdb_string(), \"pdb\", viewer=(0, i+1))\n",
+ "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n",
+ "view.addStyle({\"resi\": (np.arange(span_start, span_end) + 1).tolist()}, {\"cartoon\": {\"color\": \"red\"}})\n",
+ "view.zoomTo()\n",
+ "view.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
}
diff --git a/examples/gfp_design.ipynb b/examples/gfp_design.ipynb
index 96cf598..6bc0e08 100644
--- a/examples/gfp_design.ipynb
+++ b/examples/gfp_design.ipynb
@@ -121,7 +121,7 @@
"cell_type": "code",
"source": [
"model = client(\n",
- " model=\"esm3-md-alpha1\",\n",
+ " model=\"esm3-medium-2024-03\",\n",
" url=\"https://forge.evolutionaryscale.ai\",\n",
" token=token,\n",
")"
diff --git a/pyproject.toml b/pyproject.toml
index b8bd44f..741827f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "esm"
-version = "3.0.4"
+version = "3.0.5"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"