mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Added aliases to open model (#112)
This commit is contained in:
16
README.md
16
README.md
@@ -1,7 +1,8 @@
|
||||
# ESM3
|
||||
|
||||
[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.
|
||||
|
||||
ESM3 is a *generative* masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
|
||||
ESM3 is a _generative_ masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the `.generate()` function does.
|
||||
|
||||
<!---->
|
||||
<img src="_assets/esm3_diagram.png" alt="ESM3 Diagram" width="400" />
|
||||
@@ -13,7 +14,6 @@ Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and f
|
||||
ESM3-open is available under a [non-commercial license](https://www.evolutionaryscale.ai/policies/community-license-agreement), reproduced under `LICENSE.md`.
|
||||
Visit our [Discussions page](https://github.com/evolutionaryscale/esm/discussions) to get in touch, provide feedback, ask questions or share your experience with ESM3!
|
||||
|
||||
|
||||
## Quickstart for ESM3-open
|
||||
|
||||
```
|
||||
@@ -61,10 +61,12 @@ Let's explore some more advanced prompting with the help of our [notebooks and s
|
||||
[<img src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/evolutionaryscale/esm/blob/main/examples/gfp_design.ipynb)
|
||||
|
||||
We also provide example scripts that show common workflows under `examples/`:
|
||||
* [local_generate.py](./examples/local_generate.py) shows how simple and elegant common tasks are: it shows folding, inverse folding and chain of thought generation, all by calling just `model.generate()` for iterative decoding.
|
||||
* [seqfun_struct.py](./examples/seqfun_struct.py) shows direct use of the model as a standard pytorch model with a simple model `forward` call.
|
||||
|
||||
- [local_generate.py](./examples/local_generate.py) shows how simple and elegant common tasks are: it shows folding, inverse folding and chain of thought generation, all by calling just `model.generate()` for iterative decoding.
|
||||
- [seqfun_struct.py](./examples/seqfun_struct.py) shows direct use of the model as a standard pytorch model with a simple model `forward` call.
|
||||
|
||||
## Forge: Access to larger ESM3 models
|
||||
|
||||
You can apply for beta access to the full family of larger and higher capability ESM3 models at [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai).
|
||||
|
||||
We encourage users to interact with the Forge API through the python `esm` library instead of the command line.
|
||||
@@ -72,14 +74,16 @@ The python interface enables you to interactively load proteins, build prompts,
|
||||
with the `ESMProtein` and config classes used to interact with the local model.
|
||||
|
||||
In any example script try to replace a local `ESM3` model with a Forge API client:
|
||||
|
||||
```py
|
||||
# Instead of loading the model locally on your machine:
|
||||
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda") # or "cpu"
|
||||
# just replace the line with this:
|
||||
model: ESM3InferenceClient = esm.sdk.client("esm3-md-v1", token="<your forge token>")
|
||||
model: ESM3InferenceClient = esm.sdk.client("esm3-medium-2024-08", token="<your forge token>")
|
||||
# and now you're interfacing with the model running on our remote servers.
|
||||
...
|
||||
```
|
||||
|
||||
and the exact same code will work.
|
||||
This enables a seamless transition from smaller and faster models, to our large 98B protein language models for protein design work.
|
||||
|
||||
@@ -96,7 +100,6 @@ The core tenets of our framework are
|
||||
|
||||
With this in mind, we have performed a variety of mitigations for `esm3-sm-open-v1`, detailed in our [paper](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model)
|
||||
|
||||
|
||||
## License
|
||||
|
||||
**The Big Picture:**
|
||||
@@ -113,5 +116,4 @@ With this in mind, we have performed a variety of mitigations for `esm3-sm-open-
|
||||
|
||||
3. You **can publish, share and adapt** the EvolutionaryScale AI Model and its outputs for **non-commercial purposes** in accordance with the Community License Agreement, including a **non-commercial restriction** on the adapted model.
|
||||
|
||||
|
||||
Please read our non-commercial [Community License Agreement](https://www.evolutionaryscale.ai/policies/community-license-agreement) reproduced under [./LICENSE.md](LICENSE.md) before using ESM3.
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "3.0.4"
|
||||
__version__ = "3.0.5"
|
||||
|
||||
@@ -34,7 +34,10 @@ from esm.tokenization import (
|
||||
)
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
from esm.utils.decoding import decode_protein_tensor
|
||||
from esm.utils.generation import (
|
||||
_batch_forward,
|
||||
@@ -240,7 +243,8 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
) -> ESM3:
|
||||
from esm.pretrained import load_local_model
|
||||
|
||||
if model_name not in [ESM3_OPEN_SMALL]:
|
||||
model_name = normalize_model_name(model_name)
|
||||
if not model_name:
|
||||
raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -254,6 +258,10 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def raw_model(self):
|
||||
return self
|
||||
|
||||
def get_structure_encoder(self) -> StructureTokenEncoder:
|
||||
if self._structure_encoder is None:
|
||||
self._structure_encoder = self.structure_encoder_fn(self.device)
|
||||
|
||||
@@ -202,6 +202,7 @@ class ESMProteinTensor(ProteinType):
|
||||
|
||||
@define
|
||||
class ESMProteinError(Exception, ProteinType):
|
||||
error_code: int # Error code follows HTTP convention, i.e., 404 NotFoundError, 500 InternalError.
|
||||
error_msg: str
|
||||
|
||||
|
||||
@@ -343,3 +344,8 @@ class ESM3InferenceClient(ABC):
|
||||
# This is the way for power users to run ESM3. We hope to design this in a way to enable high throughput
|
||||
# inference, as well as arbitrary chain-of-though invocations of ESM3.
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def raw_model(self):
|
||||
# Get underlying esm3 model of an inference client.
|
||||
raise NotImplementedError
|
||||
|
||||
114
esm/sdk/forge.py
114
esm/sdk/forge.py
@@ -1,9 +1,16 @@
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Sequence
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_result,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
@@ -30,6 +37,20 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
|
||||
return [FunctionAnnotation(*t) for t in l]
|
||||
|
||||
|
||||
def retry_if_specific_error(exception):
|
||||
"""
|
||||
We only retry on specific errors.
|
||||
Currently we retry for 502 (bad gateway) and 429 (rate limit)
|
||||
"""
|
||||
return isinstance(exception, ESMProteinError) and exception.error_code in {429, 502}
|
||||
|
||||
|
||||
def log_retry_attempt(retry_state):
|
||||
print(
|
||||
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
|
||||
)
|
||||
|
||||
|
||||
class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -37,24 +58,58 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
url: str = "https://forge.evolutionaryscale.ai",
|
||||
token: str = "",
|
||||
request_timeout: int | None = None,
|
||||
min_retry_wait: int = 1,
|
||||
max_retry_wait: int = 10,
|
||||
max_retry_attempts: int = 5,
|
||||
):
|
||||
if token == "":
|
||||
raise RuntimeError(
|
||||
"Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
|
||||
)
|
||||
self.model = model
|
||||
self.model = model # Name of the model to run.
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.headers = {"Authorization": f"Bearer {self.token}"}
|
||||
self.request_timeout = request_timeout
|
||||
self.min_retry_wait = min_retry_wait
|
||||
self.max_retry_wait = max_retry_wait
|
||||
self.max_retry_attempts = max_retry_attempts
|
||||
|
||||
@staticmethod
|
||||
def retry_decorator(func):
|
||||
"""
|
||||
A static method that returns a retry decorator. This decorator uses the
|
||||
instance's retry settings.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(instance, *args, **kwargs):
|
||||
retry_decorator = retry(
|
||||
retry=retry_if_result(retry_if_specific_error),
|
||||
wait=wait_exponential(
|
||||
multiplier=1,
|
||||
min=instance.min_retry_wait,
|
||||
max=instance.max_retry_wait,
|
||||
),
|
||||
stop=stop_after_attempt(instance.max_retry_attempts),
|
||||
before_sleep=log_retry_attempt,
|
||||
)
|
||||
# Apply the retry decorator to the function
|
||||
return retry_decorator(func)(instance, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@retry_decorator
|
||||
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
|
||||
if isinstance(input, ESMProtein):
|
||||
output = self.__generate_protein(input, config)
|
||||
elif isinstance(input, ESMProteinTensor):
|
||||
output = self.__generate_protein_tensor(input, config)
|
||||
else:
|
||||
return ESMProteinError(error_msg=f"Unknown input type {type(input)}")
|
||||
return ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Unknown input type {type(input)}",
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(output, ESMProtein)
|
||||
@@ -72,7 +127,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
return output
|
||||
|
||||
def batch_generate(
|
||||
self, inputs: list[ProteinType], configs: list[GenerationConfig]
|
||||
self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
|
||||
) -> Sequence[ProteinType]:
|
||||
"""Forge supports auto-batching. So batch_generate() for the Forge client
|
||||
is as simple as running a collection of generate() in parallel using asyncio.
|
||||
@@ -88,10 +143,12 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
|
||||
results = loop.run_until_complete(_async_generate())
|
||||
|
||||
return [
|
||||
r if not isinstance(r, BaseException) else ESMProteinError(str(r))
|
||||
for r in results
|
||||
]
|
||||
def _capture_exception(r):
|
||||
if isinstance(r, BaseException) and not isinstance(r, ESMProteinError):
|
||||
return ESMProteinError(500, str(r))
|
||||
return r
|
||||
|
||||
return [_capture_exception(r) for r in results]
|
||||
|
||||
def __generate_protein(
|
||||
self,
|
||||
@@ -119,8 +176,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
}
|
||||
try:
|
||||
data = self._post("generate", request, input.potential_sequence_of_concern)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProtein(
|
||||
sequence=data["outputs"]["sequence"],
|
||||
@@ -166,8 +223,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
data = self._post(
|
||||
"generate_tensor", request, input.potential_sequence_of_concern
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
def _field_to_tensor(field, convert_none_to_nan: bool = False):
|
||||
if field not in data["outputs"]:
|
||||
@@ -188,6 +245,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
|
||||
return output
|
||||
|
||||
@retry_decorator
|
||||
def forward_and_sample(
|
||||
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
|
||||
) -> ForwardAndSampleOutput | ESMProteinError:
|
||||
@@ -237,8 +295,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
data = self._post(
|
||||
"forward_and_sample", request, input.potential_sequence_of_concern
|
||||
)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
def get(k, field):
|
||||
if data[k] is None:
|
||||
@@ -286,6 +344,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
)
|
||||
return output
|
||||
|
||||
@retry_decorator
|
||||
def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
|
||||
tracks = {}
|
||||
tracks["sequence"] = input.sequence
|
||||
@@ -299,8 +358,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
|
||||
try:
|
||||
data = self._post("encode", request, input.potential_sequence_of_concern)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProteinTensor(
|
||||
sequence=maybe_tensor(data["outputs"]["sequence"]),
|
||||
@@ -314,6 +373,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def decode(
|
||||
self,
|
||||
input: ESMProteinTensor,
|
||||
@@ -334,8 +394,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
|
||||
try:
|
||||
data = self._post("decode", request, input.potential_sequence_of_concern)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProtein(
|
||||
sequence=data["outputs"]["sequence"],
|
||||
@@ -351,9 +411,10 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
ptm=maybe_tensor(data["outputs"]["ptm"]),
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def logits(
|
||||
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
|
||||
) -> LogitsOutput:
|
||||
) -> LogitsOutput | ESMProteinError:
|
||||
# Note: using raw model forwards is discouraged because of the byte size
|
||||
# of the logits.
|
||||
# Please use forward_and_sample instead.
|
||||
@@ -381,7 +442,11 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
"inputs": req,
|
||||
"logits_config": logits_config,
|
||||
}
|
||||
data = self._post("logits", request, input.potential_sequence_of_concern)
|
||||
|
||||
try:
|
||||
data = self._post("logits", request, input.potential_sequence_of_concern)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
def _maybe_logits(track: str):
|
||||
if "logits" in data and track in data["logits"]:
|
||||
@@ -413,7 +478,10 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failure in {endpoint}: {response.text}")
|
||||
raise ESMProteinError(
|
||||
error_code=response.status_code,
|
||||
error_msg=f"Failure in {endpoint}: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# Nextjs puts outputs dict under "data" key.
|
||||
@@ -422,3 +490,9 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
data = data["data"]
|
||||
|
||||
return data
|
||||
|
||||
@property
|
||||
def raw_model(self):
|
||||
raise NotImplementedError(
|
||||
f"Can not get underlying remote model {self.model} from a Forge client."
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
|
||||
from .function_tokenizer import InterProQuantizedTokenizer
|
||||
from .residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
@@ -32,7 +35,7 @@ class TokenizerCollection:
|
||||
|
||||
|
||||
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
|
||||
if model == ESM3_OPEN_SMALL:
|
||||
if normalize_model_name(model) == ESM3_OPEN_SMALL:
|
||||
return TokenizerCollection(
|
||||
sequence=EsmSequenceTokenizer(),
|
||||
structure=StructureTokenizer(),
|
||||
|
||||
@@ -1,5 +1,23 @@
|
||||
# Model names
|
||||
ESM3_OPEN_SMALL = "esm3_sm_open_v1"
|
||||
ESM3_OPEN_SMALL_ALIAS_1 = "esm3-small-open-2024-03"
|
||||
ESM3_OPEN_SMALL_ALIAS_2 = "esm3-sm-open-v1"
|
||||
ESM3_OPEN_SMALL_ALIAS_3 = "esm3-open"
|
||||
ESM3_STRUCTURE_ENCODER_V0 = "esm3_structure_encoder_v0"
|
||||
ESM3_STRUCTURE_DECODER_V0 = "esm3_structure_decoder_v0"
|
||||
ESM3_FUNCTION_DECODER_V0 = "esm3_function_decoder_v0"
|
||||
|
||||
|
||||
def model_is_locally_supported(x: str):
|
||||
return x in {
|
||||
ESM3_OPEN_SMALL,
|
||||
ESM3_OPEN_SMALL_ALIAS_1,
|
||||
ESM3_OPEN_SMALL_ALIAS_2,
|
||||
ESM3_OPEN_SMALL_ALIAS_3,
|
||||
}
|
||||
|
||||
|
||||
def normalize_model_name(x: str):
|
||||
if x in {ESM3_OPEN_SMALL_ALIAS_1, ESM3_OPEN_SMALL_ALIAS_2, ESM3_OPEN_SMALL_ALIAS_3}:
|
||||
return ESM3_OPEN_SMALL
|
||||
return x
|
||||
|
||||
@@ -299,6 +299,23 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
|
||||
return where_to_sample
|
||||
|
||||
|
||||
def _get_non_special_tokens(
|
||||
protein: ESMProteinTensor, tokenizers: TokenizerCollectionProtocol
|
||||
) -> int:
|
||||
if protein.sequence is None:
|
||||
# There is no sequence to infer the number of tokens to decode.
|
||||
# So we assume the entire sequence minus bos and eos are for decoding.
|
||||
return len(protein) - 2
|
||||
|
||||
mask = torch.ones_like(protein.sequence)
|
||||
for special_token in tokenizers.sequence.special_token_ids:
|
||||
if special_token == tokenizers.sequence.mask_token_id:
|
||||
continue # MASK tokens need to be sampled.
|
||||
mask[protein.sequence == special_token] = 0
|
||||
|
||||
return int(torch.sum(mask).item())
|
||||
|
||||
|
||||
def iterative_sampling_tokens(
|
||||
client: ESM3InferenceClient,
|
||||
input_tokens: list[ESMProteinTensor],
|
||||
@@ -320,12 +337,12 @@ def iterative_sampling_tokens(
|
||||
sequence_lengths = [len(tokens) for tokens in sampled_tokens]
|
||||
# Figure out the number of tokens to be sampled for each prompt.
|
||||
total_to_sample = []
|
||||
for protein, seq_len, config in zip(sampled_tokens, sequence_lengths, configs):
|
||||
for protein, config in zip(sampled_tokens, configs):
|
||||
track = config.track
|
||||
|
||||
if getattr(protein, track) is None:
|
||||
# We need to sample the entire track.
|
||||
total_to_sample.append(seq_len - 2)
|
||||
total_to_sample.append(_get_non_special_tokens(protein, tokenizers))
|
||||
continue
|
||||
|
||||
masked = _get_masked_positions(
|
||||
@@ -368,7 +385,8 @@ def iterative_sampling_tokens(
|
||||
|
||||
if config.track in ["coordinates", "residue_annotations"]:
|
||||
errors[i] = ESMProteinError(
|
||||
error_msg=f"Iterative sampling {config.track} is not supported."
|
||||
error_code=500,
|
||||
error_msg=f"Iterative sampling {config.track} is not supported.",
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -427,7 +445,7 @@ def iterative_sampling_tokens(
|
||||
tokenizers,
|
||||
)
|
||||
except ValueError as e:
|
||||
errors[i] = ESMProteinError(error_msg=str(e))
|
||||
errors[i] = ESMProteinError(error_code=500, error_msg=str(e))
|
||||
continue
|
||||
|
||||
where_to_sample.to(input_tokens[0].device)
|
||||
@@ -547,7 +565,7 @@ def _sample_per_prompt(
|
||||
|
||||
valid_ids = (
|
||||
set(tokenizers.sasa.all_token_ids)
|
||||
- set(tokenizer.special_token_ids)
|
||||
- set(tokenizers.sasa.special_token_ids)
|
||||
- set(config.invalid_ids)
|
||||
)
|
||||
sasa_logits = logits_output.logits.sasa
|
||||
@@ -568,7 +586,8 @@ def _sample_per_prompt(
|
||||
|
||||
# Sample function and residue annotations separately
|
||||
config = getattr(sampling_config, "function")
|
||||
if config is None:
|
||||
function_logits = getattr(logits_output.logits, "function")
|
||||
if config is None or function_logits is None:
|
||||
tokens_dir["function"] = maybe_clone(getattr(protein, "function"))
|
||||
tokens_dir["residue_annotations"] = maybe_clone(
|
||||
getattr(protein, "residue_annotations")
|
||||
@@ -580,7 +599,7 @@ def _sample_per_prompt(
|
||||
sampling_metadata = _sample_function_track(
|
||||
tokenizers.function,
|
||||
tokens=getattr(protein, "function"),
|
||||
logits=getattr(logits_output.logits, "function"),
|
||||
logits=function_logits,
|
||||
sampling_track_config=config,
|
||||
)
|
||||
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
|
||||
|
||||
116
esm/utils/generation_test.py
Normal file
116
esm/utils/generation_test.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESMProtein,
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
GenerationConfig,
|
||||
)
|
||||
from evolutionaryscale.utils.env import ModelName
|
||||
from evolutionaryscale.utils.remote_inference.api_v1 import (
|
||||
ESM3RemoteModelInferenceClient,
|
||||
)
|
||||
from projects.forge.fastapi.utils.model import _load_esm3
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def esm3_remote_inference_client():
|
||||
model = _load_esm3(ModelName.ESM3_TINY_DEV, distributed_model=False)
|
||||
client = ESM3RemoteModelInferenceClient(
|
||||
model,
|
||||
tokenizers=model.tokenizers,
|
||||
device=torch.device("cuda"),
|
||||
enable_batched_runner=False,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_chain_break_tokens(esm3_remote_inference_client):
|
||||
tokenizer = esm3_remote_inference_client.tokenizers.sequence
|
||||
# 3 separate chains with 2 chainbreak tokens.
|
||||
sequence_with_chain_breaks = torch.tensor(
|
||||
[
|
||||
tokenizer.bos_token_id,
|
||||
20,
|
||||
20,
|
||||
20,
|
||||
20,
|
||||
tokenizer.chain_break_token_id,
|
||||
21,
|
||||
21,
|
||||
21,
|
||||
tokenizer.chain_break_token_id,
|
||||
22,
|
||||
22,
|
||||
22,
|
||||
tokenizer.eos_token_id,
|
||||
]
|
||||
)
|
||||
protein = esm3_remote_inference_client.generate(
|
||||
ESMProteinTensor(sequence=sequence_with_chain_breaks),
|
||||
# There are 10 tokens that actually need to be sampled.
|
||||
GenerationConfig(track="structure", num_steps=10),
|
||||
)
|
||||
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.structure is not None
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_num_decoding_steps_more_than_mask_tokens_fails(esm3_remote_inference_client):
|
||||
protein = esm3_remote_inference_client.generate(
|
||||
ESMProtein(sequence="CDEFG"), # sequence of 5.
|
||||
GenerationConfig(track="structure", num_steps=10), # use 10 decoding steps.
|
||||
)
|
||||
# Can't specify more decoding steps than masks available.
|
||||
assert isinstance(protein, ESMProteinError)
|
||||
assert protein.error_code == 500
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_encode_chainbreak_token(esm3_remote_inference_client):
|
||||
protein = esm3_remote_inference_client.encode(
|
||||
ESMProtein(sequence="MSTNP|KPQKK"),
|
||||
)
|
||||
# Can't specify more decoding steps than masks available.
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.sequence is not None
|
||||
assert (
|
||||
protein.sequence[6]
|
||||
== esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_generation_with_chainbreak_token(esm3_remote_inference_client):
|
||||
chainbreak_sequence = torch.tensor(
|
||||
[
|
||||
esm3_remote_inference_client.tokenizers.sequence.bos_token_id,
|
||||
20,
|
||||
8,
|
||||
11,
|
||||
17,
|
||||
14,
|
||||
esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id,
|
||||
15,
|
||||
14,
|
||||
16,
|
||||
15,
|
||||
15,
|
||||
esm3_remote_inference_client.tokenizers.sequence.eos_token_id,
|
||||
]
|
||||
)
|
||||
|
||||
protein = esm3_remote_inference_client.generate(
|
||||
ESMProteinTensor(sequence=chainbreak_sequence),
|
||||
GenerationConfig(track="structure", num_steps=1),
|
||||
)
|
||||
# Can't specify more decoding steps than masks available.
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.structure is not None
|
||||
assert (
|
||||
protein.structure[6]
|
||||
== esm3_remote_inference_client.tokenizers.structure.chain_break_token_id
|
||||
)
|
||||
@@ -163,6 +163,10 @@ def sample_logits(
|
||||
logits is shape (..., vocab_size)
|
||||
temperature is broadcastable to (...)
|
||||
"""
|
||||
if len(valid_ids) == 0:
|
||||
raise ValueError(
|
||||
"Can not sample logits if there are no valid ids to sample from."
|
||||
)
|
||||
|
||||
if top_p < 1.0:
|
||||
logits = top_p_logits(logits, top_p=top_p)
|
||||
@@ -181,7 +185,7 @@ def sample_logits(
|
||||
|
||||
if torch.all(temperature == 0):
|
||||
ids = logits.argmax(-1)
|
||||
return ids
|
||||
return ids.reshape(*batch_dims)
|
||||
|
||||
assert not torch.any(temperature == 0), "Partial temperature 0 not supported."
|
||||
|
||||
|
||||
40
esm/utils/sampling_test.py
Normal file
40
esm/utils/sampling_test.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from esm.utils.sampling import sample_logits
|
||||
|
||||
|
||||
def test_sample_logits():
|
||||
# batched input. temperature != 0.0.
|
||||
sampled = sample_logits(
|
||||
logits=torch.randn((64, 8, 4096)), temperature=0.8, valid_ids=list(range(4096))
|
||||
)
|
||||
assert sampled.shape == (64, 8)
|
||||
|
||||
# batched input. temperature == 0.0.
|
||||
sampled = sample_logits(
|
||||
logits=torch.randn((64, 8, 4096)), temperature=0.0, valid_ids=list(range(4096))
|
||||
)
|
||||
assert sampled.shape == (64, 8)
|
||||
|
||||
# non-batched input. temperature != 0.0.
|
||||
sampled = sample_logits(
|
||||
logits=torch.randn((8, 4096)), temperature=0.8, valid_ids=list(range(4096))
|
||||
)
|
||||
assert sampled.shape == (8,)
|
||||
|
||||
# non-batched input. temperature == 0.0.
|
||||
sampled = sample_logits(
|
||||
logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=list(range(4096))
|
||||
)
|
||||
assert sampled.shape == (8,)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
sampled = sample_logits(
|
||||
logits=torch.randn((8, 4096)),
|
||||
temperature=0.0,
|
||||
valid_ids=[],
|
||||
)
|
||||
|
||||
|
||||
test_sample_logits()
|
||||
@@ -157,7 +157,10 @@ def create_esm3_generation_launcher(
|
||||
model=model_name.value, token=forge_token
|
||||
)
|
||||
elif isinstance(client, ESM3):
|
||||
if model_name.value != models.ESM3_OPEN_SMALL:
|
||||
if (
|
||||
models.normalize_model_name(model_name.value)
|
||||
!= models.ESM3_OPEN_SMALL
|
||||
):
|
||||
raise ValueError(
|
||||
f"Model name {model_name.value} does not match the client model {models.ESM3_OPEN_SMALL}"
|
||||
)
|
||||
|
||||
@@ -92,7 +92,7 @@ def create_login_ui(client_container: ClientInitContainer):
|
||||
layout={"width": "50%"},
|
||||
)
|
||||
forge_model = widgets.Text(
|
||||
value="esm3-md-v1",
|
||||
value="esm3-medium-2024-08",
|
||||
description="Model Name:",
|
||||
disabled=False,
|
||||
layout={"width": "50%"},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -121,7 +121,7 @@
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"model = client(\n",
|
||||
" model=\"esm3-md-alpha1\",\n",
|
||||
" model=\"esm3-medium-2024-03\",\n",
|
||||
" url=\"https://forge.evolutionaryscale.ai\",\n",
|
||||
" token=token,\n",
|
||||
")"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.0.4"
|
||||
version = "3.0.5"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user