mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
ported over updates from internal (#208)
Co-authored-by: chetan <chetan@evolutionaryscale.ai>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1 +1,4 @@
|
||||
esm.egg-info
|
||||
# pixi environments
|
||||
.pixi
|
||||
*.egg-info
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import os
|
||||
|
||||
from esm.models.esmc import ESMC
|
||||
from esm.sdk import client
|
||||
from esm.sdk.api import (
|
||||
ESMCInferenceClient,
|
||||
ESMProtein,
|
||||
ESMProteinTensor,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
)
|
||||
from esm.sdk.forge import ESM3ForgeInferenceClient
|
||||
|
||||
|
||||
def main(client: ESMCInferenceClient):
|
||||
def main(client: ESMCInferenceClient | ESM3ForgeInferenceClient):
|
||||
# ================================================================
|
||||
# Example usage: one single protein
|
||||
# ================================================================
|
||||
@@ -15,13 +20,16 @@ def main(client: ESMCInferenceClient):
|
||||
|
||||
# Use logits endpoint. Using bf16 for inference optimization
|
||||
protein_tensor = client.encode(protein)
|
||||
assert isinstance(
|
||||
protein_tensor, ESMProteinTensor
|
||||
), f"Expected ESMProteinTensor but got error: {protein_tensor}"
|
||||
output = client.logits(
|
||||
protein_tensor,
|
||||
LogitsConfig(sequence=True, return_embeddings=True, return_hidden_states=True),
|
||||
)
|
||||
assert isinstance(
|
||||
output, LogitsOutput
|
||||
), f"LogitsOutput was expected but got {output}"
|
||||
), f"LogitsOutput was expected but got error: {output}"
|
||||
assert output.logits is not None and output.logits.sequence is not None
|
||||
assert output.embeddings is not None
|
||||
assert output.hidden_states is not None
|
||||
@@ -30,9 +38,15 @@ def main(client: ESMCInferenceClient):
|
||||
)
|
||||
|
||||
# request a specific hidden layer.
|
||||
assert isinstance(
|
||||
protein_tensor, ESMProteinTensor
|
||||
), f"Expected ESMProteinTensor but got error: {protein_tensor}"
|
||||
output = client.logits(
|
||||
protein_tensor, LogitsConfig(return_hidden_states=True, ith_hidden_layer=1)
|
||||
)
|
||||
assert isinstance(
|
||||
output, LogitsOutput
|
||||
), f"LogitsOutput was expected but got error: {output}"
|
||||
assert output.hidden_states is not None
|
||||
print(f"Client returned hidden states with shape {output.hidden_states.shape}")
|
||||
|
||||
@@ -57,6 +71,15 @@ def raw_forward(model: ESMC):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = ESMC.from_pretrained("esmc_300m")
|
||||
main(model)
|
||||
raw_forward(model)
|
||||
if os.environ.get("ESM_API_KEY", ""):
|
||||
print("ESM_API_KEY found. Trying to use model from Forge...")
|
||||
main(client(model="esmc-300m-2024-12"))
|
||||
else:
|
||||
print("No ESM_API_KEY found. Trying to load model locally...")
|
||||
print(
|
||||
"TO try this script with a Forge API, please run ESM_API_KEY=your_api_key python esm3.py"
|
||||
)
|
||||
main(ESMC.from_pretrained("esm3_sm_open_v1"))
|
||||
model = ESMC.from_pretrained("esmc_300m")
|
||||
main(model)
|
||||
raw_forward(model)
|
||||
|
||||
@@ -50,7 +50,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"id": "poK5NTaXRGcX"
|
||||
},
|
||||
@@ -85,7 +85,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"id": "zNrU9Q2SYonX"
|
||||
},
|
||||
@@ -105,7 +105,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"id": "Tna_mjGOjdXA"
|
||||
},
|
||||
@@ -367,7 +367,8 @@
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"num_tokens_to_decode = (prompt.sequence == 32).sum().item()\n",
|
||||
"# Based on internal, there's not a benefit to iterative decoding past 20 steps\n",
|
||||
"num_tokens_to_decode = min((prompt.sequence == 32).sum().item(), 20)\n",
|
||||
"\n",
|
||||
"sequence_generation = model.generate(\n",
|
||||
" # Generate a sequence.\n",
|
||||
@@ -380,7 +381,7 @@
|
||||
"length_of_sequence = sequence_generation.sequence.numel() - 2\n",
|
||||
"sequence_generation = model.generate(\n",
|
||||
" sequence_generation,\n",
|
||||
" GenerationConfig(track=\"structure\", num_steps=length_of_sequence, temperature=0.0),\n",
|
||||
" GenerationConfig(track=\"structure\", num_steps=1, temperature=0.0),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Decode to AA string and coordinates.\n",
|
||||
@@ -528,11 +529,21 @@
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "default",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
__version__ = "3.1.4"
|
||||
__version__ = "3.1.5"
|
||||
|
||||
|
||||
@@ -266,7 +266,8 @@ class ESMProteinError(Exception, ProteinType):
|
||||
@define
|
||||
class GenerationConfig:
|
||||
track: str = ""
|
||||
invalid_ids: Sequence[int] = []
|
||||
# By default avoid sampling the amino acid "X"
|
||||
invalid_ids: Sequence[int] = [24]
|
||||
# Controls the number of tokens to unmask during each round of iterative generation.
|
||||
schedule: str = attr.field(
|
||||
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
|
||||
@@ -275,11 +276,11 @@ class GenerationConfig:
|
||||
# "random" will unmask a correct number of tokens randomly.
|
||||
# "entropy" will unmask the tokens with the lowest logit entropy first.
|
||||
strategy: str = attr.field(
|
||||
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
|
||||
validator=attr.validators.in_(["random", "entropy"]), default="random"
|
||||
)
|
||||
# Set this to a higher value for better generation results.
|
||||
# The default is 20, because there are diminishing returns beyond 20 decoding steps.
|
||||
# Note that this needs to be less than or equal to the sequence length.
|
||||
num_steps: int = 1
|
||||
num_steps: int = 20
|
||||
temperature: float = 1.0
|
||||
temperature_annealing: bool = False
|
||||
top_p: float = 1.0
|
||||
|
||||
@@ -60,6 +60,10 @@ def log_retry_attempt(retry_state):
|
||||
|
||||
|
||||
def _validate_protein_tensor_input(input):
|
||||
if isinstance(input, ESMProteinError):
|
||||
raise ValueError(
|
||||
f"Input must be an ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
|
||||
)
|
||||
if not isinstance(input, ESMProteinTensor):
|
||||
raise ValueError(
|
||||
f"Input must be an ESMProteinTensor instance, but received {type(input)} instead. "
|
||||
@@ -71,14 +75,25 @@ class SequenceStructureForgeInferenceClient:
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "https://forge.evolutionaryscale.ai",
|
||||
model: str | None = None,
|
||||
token: str = "",
|
||||
request_timeout: int | None = None,
|
||||
):
|
||||
"""
|
||||
Forge client for folding and inverse folding between sequence and structure spaces.
|
||||
|
||||
Args:
|
||||
url: URL of the Forge server.
|
||||
model: Name of the model to be used for folding / inv folding.
|
||||
token: API token.
|
||||
request_timeout: Override the system default request timeout, in seconds.
|
||||
"""
|
||||
if token == "":
|
||||
raise RuntimeError(
|
||||
"Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
|
||||
)
|
||||
self.url = url
|
||||
self.model = model
|
||||
self.token = token
|
||||
self.headers = {"Authorization": f"Bearer {self.token}"}
|
||||
self.request_timeout = request_timeout
|
||||
@@ -89,9 +104,19 @@ class SequenceStructureForgeInferenceClient:
|
||||
potential_sequence_of_concern: bool,
|
||||
model_name: str | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
"""Predict coordinates for a protein sequence.
|
||||
|
||||
Args:
|
||||
sequence: Protein sequence to be folded.
|
||||
potential_sequence_of_concern: Self-disclosed potential_sequence_of_concern bit.
|
||||
This bit is largely ignored by the fold() endpoint.
|
||||
model_name: Override the client level model name if needed.
|
||||
"""
|
||||
request = {"sequence": sequence}
|
||||
if model_name is not None:
|
||||
request["model"] = model_name
|
||||
elif self.model is not None:
|
||||
request["model"] = self.model
|
||||
try:
|
||||
data = self._post("fold", request, potential_sequence_of_concern)
|
||||
except ESMProteinError as e:
|
||||
@@ -109,6 +134,17 @@ class SequenceStructureForgeInferenceClient:
|
||||
potential_sequence_of_concern: bool,
|
||||
model_name: str | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
"""Generate protein sequence from its structure.
|
||||
|
||||
This endpoint is only supported by generative models like ESM3.
|
||||
|
||||
Args:
|
||||
coordinates: Structure coordinates of the protein to be inverse folded.
|
||||
config: Configurations related to inverse folding generation.
|
||||
potential_sequence_of_concern: Self-disclosed potential_sequence_of_concern bit.
|
||||
Requires special permission to use.
|
||||
model_name: Override the client level model name if needed.
|
||||
"""
|
||||
inverse_folding_config = {
|
||||
"invalid_ids": config.invalid_ids,
|
||||
"temperature": config.temperature,
|
||||
@@ -119,6 +155,8 @@ class SequenceStructureForgeInferenceClient:
|
||||
}
|
||||
if model_name is not None:
|
||||
request["model"] = model_name
|
||||
elif self.model is not None:
|
||||
request["model"] = self.model
|
||||
try:
|
||||
data = self._post("inverse_fold", request, potential_sequence_of_concern)
|
||||
except ESMProteinError as e:
|
||||
@@ -208,6 +246,16 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
|
||||
@retry_decorator
|
||||
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
|
||||
if isinstance(input, ESMProteinError):
|
||||
raise ValueError(
|
||||
f"Input must be an ESMProtein or ESMProteinTensor instance, but received an ESMProteinError instead: {input.error_code} {input.error_msg}"
|
||||
)
|
||||
assert isinstance(input, ESMProtein) or isinstance(input, ESMProteinTensor)
|
||||
if input.sequence is not None and config.num_steps > len(input.sequence):
|
||||
config.num_steps = len(input.sequence)
|
||||
print(
|
||||
"Warning: num_steps cannot exceed sequence length. Setting num_steps to sequence length."
|
||||
)
|
||||
if isinstance(input, ESMProtein):
|
||||
output = self.__generate_protein(input, config)
|
||||
elif isinstance(input, ESMProteinTensor):
|
||||
|
||||
@@ -9,14 +9,14 @@ from esm.sdk.forge import (
|
||||
|
||||
|
||||
class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
|
||||
def __init__(self, endpoint_name: str):
|
||||
def __init__(self, endpoint_name: str, model: str | None = None):
|
||||
"""SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_name: Name of the SageMaker endpoint.
|
||||
"""
|
||||
# Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
|
||||
super().__init__(url="", token="dummy")
|
||||
super().__init__(url="", model=model, token="dummy")
|
||||
|
||||
self._endpoint_name = endpoint_name
|
||||
|
||||
|
||||
@@ -72,13 +72,45 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
|
||||
def bos_token_id(self):
|
||||
return self.cls_token_id
|
||||
|
||||
@property
|
||||
def cls_token(self):
|
||||
return self._get_token("cls_token")
|
||||
|
||||
@property
|
||||
def cls_token_id(self):
|
||||
return self._get_token_id(self.cls_token)
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self._get_token("eos_token")
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self._get_token_id(self.eos_token)
|
||||
|
||||
@property
|
||||
def mask_token(self):
|
||||
return self._get_token("mask_token")
|
||||
|
||||
@property
|
||||
def mask_token_id(self):
|
||||
return self._get_token_id(self.mask_token)
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return self._get_token("pad_token")
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self._get_token_id(self.pad_token)
|
||||
|
||||
@property
|
||||
def chain_break_token(self):
|
||||
return self.cb_token
|
||||
|
||||
@property
|
||||
def chain_break_token_id(self):
|
||||
return self.convert_tokens_to_ids(self.chain_break_token)
|
||||
return self._get_token_id(self.chain_break_token)
|
||||
|
||||
@property
|
||||
def all_token_ids(self):
|
||||
@@ -87,3 +119,16 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
|
||||
@property
|
||||
def special_token_ids(self):
|
||||
return self.all_special_ids
|
||||
|
||||
def _get_token_id(self, token) -> int:
|
||||
token_id = self.convert_tokens_to_ids(token)
|
||||
assert isinstance(token_id, int)
|
||||
return token_id
|
||||
|
||||
def _get_token(self, token_name: str) -> str:
|
||||
# NOTE: The transformers tokenizer base class overloads __getattr__ to expose special tokens.
|
||||
# Adding a helper method around it keeps the base class functionality without overriding
|
||||
# the property. See: https://github.com/huggingface/transformers/blob/41925e42135257361b7f02aa20e3bbdab3f7b923/src/transformers/tokenization_utils_base.py#L1086
|
||||
token_str = self.__getattr__(token_name)
|
||||
assert isinstance(token_str, str)
|
||||
return token_str
|
||||
|
||||
@@ -3,40 +3,21 @@ from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class EsmTokenizerBase(Protocol):
|
||||
mask_token: str
|
||||
mask_token_id: int
|
||||
bos_token: str
|
||||
bos_token_id: int
|
||||
eos_token: str
|
||||
eos_token_id: int
|
||||
pad_token: str
|
||||
pad_token_id: int
|
||||
chain_break_token: str
|
||||
chain_break_token_id: int
|
||||
|
||||
def encode(self, *args, **kwargs): ...
|
||||
|
||||
def decode(self, *args, **kwargs): ...
|
||||
|
||||
@property
|
||||
def mask_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def mask_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def chain_break_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def chain_break_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def all_token_ids(self): ...
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from evolutionaryscale.models.esm3v2 import Esm3v2
|
||||
from esm.sdk.api import (
|
||||
ESMProtein,
|
||||
ESMProteinTensor,
|
||||
@@ -18,6 +19,7 @@ def esm3_remote_inference_client():
|
||||
model = _load_esm_model(
|
||||
ModelName.ESM3_TINY_DEV, distributed_model=False, load_function_decoder=False
|
||||
)
|
||||
assert isinstance(model, Esm3v2)
|
||||
client = ESM3RemoteModelInferenceClient(
|
||||
model,
|
||||
tokenizers=model.tokenizers,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.1.4"
|
||||
version = "3.1.5"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -49,3 +49,14 @@ include = ["esm*"]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
esm = ["data/*"]
|
||||
|
||||
[tool.pixi.project]
|
||||
channels = ["conda-forge"]
|
||||
platforms = ["linux-64"]
|
||||
|
||||
[tool.pixi.dependencies]
|
||||
pkg-config = "*"
|
||||
cmake = "*"
|
||||
|
||||
[tool.pixi.pypi-dependencies]
|
||||
esm = { path = ".", editable = true }
|
||||
|
||||
@@ -72,6 +72,13 @@
|
||||
"client_container = ClientInitContainer()\n",
|
||||
"create_generation_ui(get_forge_client(model_name))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@@ -79,7 +86,7 @@
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "default",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user