ported over updates from internal (#208)

Co-authored-by: chetan <chetan@evolutionaryscale.ai>
This commit is contained in:
Chetan Mishra
2025-03-05 15:08:52 -05:00
committed by GitHub
parent 82b1431b7b
commit c6b8c342aa
13 changed files with 2392 additions and 53 deletions

3
.gitignore vendored
View File

@@ -1 +1,4 @@
esm.egg-info
# pixi environments
.pixi
*.egg-info

View File

@@ -1,13 +1,18 @@
import os
from esm.models.esmc import ESMC
from esm.sdk import client
from esm.sdk.api import (
ESMCInferenceClient,
ESMProtein,
ESMProteinTensor,
LogitsConfig,
LogitsOutput,
)
from esm.sdk.forge import ESM3ForgeInferenceClient
def main(client: ESMCInferenceClient):
def main(client: ESMCInferenceClient | ESM3ForgeInferenceClient):
# ================================================================
# Example usage: one single protein
# ================================================================
@@ -15,13 +20,16 @@ def main(client: ESMCInferenceClient):
# Use logits endpoint. Using bf16 for inference optimization
protein_tensor = client.encode(protein)
assert isinstance(
protein_tensor, ESMProteinTensor
), f"Expected ESMProteinTensor but got error: {protein_tensor}"
output = client.logits(
protein_tensor,
LogitsConfig(sequence=True, return_embeddings=True, return_hidden_states=True),
)
assert isinstance(
output, LogitsOutput
), f"LogitsOutput was expected but got {output}"
), f"LogitsOutput was expected but got error: {output}"
assert output.logits is not None and output.logits.sequence is not None
assert output.embeddings is not None
assert output.hidden_states is not None
@@ -30,9 +38,15 @@ def main(client: ESMCInferenceClient):
)
# request a specific hidden layer.
assert isinstance(
protein_tensor, ESMProteinTensor
), f"Expected ESMProteinTensor but got error: {protein_tensor}"
output = client.logits(
protein_tensor, LogitsConfig(return_hidden_states=True, ith_hidden_layer=1)
)
assert isinstance(
output, LogitsOutput
), f"LogitsOutput was expected but got error: {output}"
assert output.hidden_states is not None
print(f"Client returned hidden states with shape {output.hidden_states.shape}")
@@ -57,6 +71,15 @@ def raw_forward(model: ESMC):
if __name__ == "__main__":
model = ESMC.from_pretrained("esmc_300m")
main(model)
raw_forward(model)
if os.environ.get("ESM_API_KEY", ""):
print("ESM_API_KEY found. Trying to use model from Forge...")
main(client(model="esmc-300m-2024-12"))
else:
print("No ESM_API_KEY found. Trying to load model locally...")
print(
"TO try this script with a Forge API, please run ESM_API_KEY=your_api_key python esm3.py"
)
main(ESMC.from_pretrained("esm3_sm_open_v1"))
model = ESMC.from_pretrained("esmc_300m")
main(model)
raw_forward(model)

View File

@@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {
"id": "poK5NTaXRGcX"
},
@@ -85,7 +85,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {
"id": "zNrU9Q2SYonX"
},
@@ -105,7 +105,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"metadata": {
"id": "Tna_mjGOjdXA"
},
@@ -367,7 +367,8 @@
"source": [
"%%time\n",
"\n",
"num_tokens_to_decode = (prompt.sequence == 32).sum().item()\n",
"# Based on internal, there's not a benefit to iterative decoding past 20 steps\n",
"num_tokens_to_decode = min((prompt.sequence == 32).sum().item(), 20)\n",
"\n",
"sequence_generation = model.generate(\n",
" # Generate a sequence.\n",
@@ -380,7 +381,7 @@
"length_of_sequence = sequence_generation.sequence.numel() - 2\n",
"sequence_generation = model.generate(\n",
" sequence_generation,\n",
" GenerationConfig(track=\"structure\", num_steps=length_of_sequence, temperature=0.0),\n",
" GenerationConfig(track=\"structure\", num_steps=1, temperature=0.0),\n",
")\n",
"\n",
"# Decode to AA string and coordinates.\n",
@@ -528,11 +529,21 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "default",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,

View File

@@ -1,2 +1,2 @@
__version__ = "3.1.4"
__version__ = "3.1.5"

View File

@@ -266,7 +266,8 @@ class ESMProteinError(Exception, ProteinType):
@define
class GenerationConfig:
track: str = ""
invalid_ids: Sequence[int] = []
# By default avoid sampling the amino acid "X"
invalid_ids: Sequence[int] = [24]
# Controls the number of tokens to unmask during each round of iterative generation.
schedule: str = attr.field(
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
@@ -275,11 +276,11 @@ class GenerationConfig:
# "random" will unmask a correct number of tokens randomly.
# "entropy" will unmask the tokens with the lowest logit entropy first.
strategy: str = attr.field(
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
validator=attr.validators.in_(["random", "entropy"]), default="random"
)
# Set this to a higher value for better generation results.
# Setting default to 20, as there is diminishing return for decoding steps more than 20.
# Note that this needs to be less than or equal to the sequence length.
num_steps: int = 1
num_steps: int = 20
temperature: float = 1.0
temperature_annealing: bool = False
top_p: float = 1.0

View File

@@ -60,6 +60,10 @@ def log_retry_attempt(retry_state):
def _validate_protein_tensor_input(input):
if isinstance(input, ESMProteinError):
raise ValueError(
f"Input must be an ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
)
if not isinstance(input, ESMProteinTensor):
raise ValueError(
f"Input must be an ESMProteinTensor instance, but received {type(input)} instead. "
@@ -71,14 +75,25 @@ class SequenceStructureForgeInferenceClient:
def __init__(
self,
url: str = "https://forge.evolutionaryscale.ai",
model: str | None = None,
token: str = "",
request_timeout: int | None = None,
):
"""
Forge client for folding and inverse folding between sequence and structure spaces.
Args:
url: URL of the Forge server.
model: Name of the model to be used for folding / inv folding.
token: API token.
request_timeout: Override the system default request timeout, in seconds.
"""
if token == "":
raise RuntimeError(
"Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
)
self.url = url
self.model = model
self.token = token
self.headers = {"Authorization": f"Bearer {self.token}"}
self.request_timeout = request_timeout
@@ -89,9 +104,19 @@ class SequenceStructureForgeInferenceClient:
potential_sequence_of_concern: bool,
model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
"""Predict coordinates for a protein sequence.
Args:
sequence: Protein sequence to be folded.
potential_sequence_of_concern: Self disclosed potential_of_concern bit.
This bit is largely ignored by the fold() endpoint.
model_name: Override the client level model name if needed.
"""
request = {"sequence": sequence}
if model_name is not None:
request["model"] = model_name
elif self.model is not None:
request["model"] = self.model
try:
data = self._post("fold", request, potential_sequence_of_concern)
except ESMProteinError as e:
@@ -109,6 +134,17 @@ class SequenceStructureForgeInferenceClient:
potential_sequence_of_concern: bool,
model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
"""Generate protein sequence from its structure.
This endpoint is only supported by generative models like ESM3.
Args:
coordinates: Protein sequence coordinates to be inversely folded.
config: Configurations related to inverse folding generation.
potential_sequence_of_concern: Self disclosed potential_of_concern bit.
Requires special permission to use.
model_name: Override the client level model name if needed.
"""
inverse_folding_config = {
"invalid_ids": config.invalid_ids,
"temperature": config.temperature,
@@ -119,6 +155,8 @@ class SequenceStructureForgeInferenceClient:
}
if model_name is not None:
request["model"] = model_name
elif self.model is not None:
request["model"] = self.model
try:
data = self._post("inverse_fold", request, potential_sequence_of_concern)
except ESMProteinError as e:
@@ -208,6 +246,16 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
@retry_decorator
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
if isinstance(input, ESMProteinError):
raise ValueError(
f"Input must be an ESMProtein or ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
)
assert isinstance(input, ESMProtein) or isinstance(input, ESMProteinTensor)
if input.sequence is not None and config.num_steps > len(input.sequence):
config.num_steps = len(input.sequence)
print(
"Warning: num_steps cannot exceed sequence length. Setting num_steps to sequence length."
)
if isinstance(input, ESMProtein):
output = self.__generate_protein(input, config)
elif isinstance(input, ESMProteinTensor):

View File

@@ -9,14 +9,14 @@ from esm.sdk.forge import (
class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
def __init__(self, endpoint_name: str):
def __init__(self, endpoint_name: str, model: str | None = None):
"""SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.
Args:
endpoint_name: Name of the SageMaker endpoint.
"""
# Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
super().__init__(url="", token="dummy")
super().__init__(url="", model=model, token="dummy")
self._endpoint_name = endpoint_name

View File

@@ -72,13 +72,45 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
def bos_token_id(self):
return self.cls_token_id
@property
def cls_token(self):
return self._get_token("cls_token")
@property
def cls_token_id(self):
return self._get_token_id(self.cls_token)
@property
def eos_token(self):
return self._get_token("eos_token")
@property
def eos_token_id(self):
return self._get_token_id(self.eos_token)
@property
def mask_token(self):
return self._get_token("mask_token")
@property
def mask_token_id(self):
return self._get_token_id(self.mask_token)
@property
def pad_token(self):
return self._get_token("pad_token")
@property
def pad_token_id(self):
return self._get_token_id(self.pad_token)
@property
def chain_break_token(self):
return self.cb_token
@property
def chain_break_token_id(self):
return self.convert_tokens_to_ids(self.chain_break_token)
return self._get_token_id(self.chain_break_token)
@property
def all_token_ids(self):
@@ -87,3 +119,16 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
@property
def special_token_ids(self):
return self.all_special_ids
def _get_token_id(self, token) -> int:
token_id = self.convert_tokens_to_ids(token)
assert isinstance(token_id, int)
return token_id
def _get_token(self, token_name: str) -> str:
# NOTE: Tokenizers library overloads __getattr__ to expose special tokens
# Adding a helper method around it keeps the base class functionality without overriding
# the property. See: https://github.com/huggingface/transformers/blob/41925e42135257361b7f02aa20e3bbdab3f7b923/src/transformers/tokenization_utils_base.py#L1086
token_str = self.__getattr__(token_name)
assert isinstance(token_str, str)
return token_str

View File

@@ -3,40 +3,21 @@ from typing import Protocol, runtime_checkable
@runtime_checkable
class EsmTokenizerBase(Protocol):
mask_token: str
mask_token_id: int
bos_token: str
bos_token_id: int
eos_token: str
eos_token_id: int
pad_token: str
pad_token_id: int
chain_break_token: str
chain_break_token_id: int
def encode(self, *args, **kwargs): ...
def decode(self, *args, **kwargs): ...
@property
def mask_token(self) -> str: ...
@property
def mask_token_id(self) -> int: ...
@property
def bos_token(self) -> str: ...
@property
def bos_token_id(self) -> int: ...
@property
def eos_token(self) -> str: ...
@property
def eos_token_id(self) -> int: ...
@property
def pad_token(self) -> str: ...
@property
def pad_token_id(self) -> int: ...
@property
def chain_break_token(self) -> str: ...
@property
def chain_break_token_id(self) -> int: ...
@property
def all_token_ids(self): ...

View File

@@ -1,6 +1,7 @@
import pytest
import torch
from evolutionaryscale.models.esm3v2 import Esm3v2
from esm.sdk.api import (
ESMProtein,
ESMProteinTensor,
@@ -18,6 +19,7 @@ def esm3_remote_inference_client():
model = _load_esm_model(
ModelName.ESM3_TINY_DEV, distributed_model=False, load_function_decoder=False
)
assert isinstance(model, Esm3v2)
client = ESM3RemoteModelInferenceClient(
model,
tokenizers=model.tokenizers,

2207
pixi.lock Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.1.4"
version = "3.1.5"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"
@@ -49,3 +49,14 @@ include = ["esm*"]
[tool.setuptools.package-data]
esm = ["data/*"]
[tool.pixi.project]
channels = ["conda-forge"]
platforms = ["linux-64"]
[tool.pixi.dependencies]
pkg-config = "*"
cmake = "*"
[tool.pixi.pypi-dependencies]
esm = { path = ".", editable = true }

View File

@@ -72,6 +72,13 @@
"client_container = ClientInitContainer()\n",
"create_generation_ui(get_forge_client(model_name))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -79,7 +86,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "default",
"language": "python",
"name": "python3"
},