pulled into changes from internal (#198)

Co-authored-by: chetan <chetan@evolutionaryscale.ai>
This commit is contained in:
Chetan Mishra
2025-02-26 15:52:25 -05:00
committed by GitHub
parent 1b3e08025b
commit 82b1431b7b
13 changed files with 321 additions and 70 deletions

View File

@@ -154,6 +154,36 @@ print(logits_output.logits, logits_output.embeddings)
Remember to replace `<your forge token>` with your actual Forge access token.
### Forge Batch Executor
For jobs that require processing multiple inputs, the Forge Batch Executor provides a streamlined way to execute them concurrently and efficiently while respecting rate limits and adapting to request latency.
```py
from esm.sdk import batch_executor
from esm.sdk.api import ESMProtein, ESMProteinError, LogitsConfig, LogitsOutput
from esm.sdk.forge import ESM3ForgeInferenceClient


def embed_sequence(client: ESM3ForgeInferenceClient, sequence: str) -> LogitsOutput:
    """Encode a single sequence and request its logits and embeddings."""
    protein = ESMProtein(sequence=sequence)
    protein_tensor = client.encode(protein)
    # encode() can hand back an ESMProteinError instead of a tensor; raise it so
    # the batch executor sees the failure and can decide whether to retry.
    if isinstance(protein_tensor, ESMProteinError):
        raise protein_tensor
    output = client.logits(
        protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
    )
    return output


sequences = ["A", "AA", "AAA"]
client = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token="<your forge token>")

# Usage Example:
# To execute a batch job, wrap your function inside the batch executor context manager.
# Syntax:
# with batch_executor() as executor:
#     outputs = executor.execute_batch(user_func=<your_function>, **kwargs)
# The keyword names must match the parameter names of your function. List-valued
# arguments are split one element per call; any other argument is passed
# unchanged to every call.
with batch_executor() as executor:
    outputs = executor.execute_batch(
        user_func=embed_sequence, client=client, sequence=sequences
    )
```
### ESM C via SageMaker for Commercial Use <a name="esm-c-sagemaker"></a>
ESM C models are also available on Amazon SageMaker under the [Cambrian Inference Clickthrough License Agreement](https://www.evolutionaryscale.ai/policies/cambrian-inference-clickthrough-license-agreement).

View File

@@ -1,7 +1,5 @@
import os
import torch
from esm.models.esm3 import ESM3
from esm.sdk import client
from esm.sdk.api import (
@@ -9,7 +7,6 @@ from esm.sdk.api import (
ESMProtein,
ESMProteinError,
ESMProteinTensor,
ForwardAndSampleOutput,
GenerationConfig,
LogitsConfig,
LogitsOutput,
@@ -197,7 +194,6 @@ def main(client: ESM3InferenceClient):
assert isinstance(p, ESMProtein), f"ESMProtein was expected but got {p}"
if __name__ == "__main__":
if os.environ.get("ESM_API_KEY", ""):
print("ESM_API_KEY found. Trying to use model from Forge...")

View File

@@ -53,7 +53,7 @@ def fold(
esm3_client_folded_protein = esm3_client.generate(protein, config)
assert isinstance(
esm3_client_folded_protein, ESMProtein
), f"Using ESM3 client, ESMProtein was expected but got {protein}"
), f"Using ESM3 client, ESMProtein was expected but got {esm3_client_folded_protein}"
# Folding with folding client
sequence_structure_client_folded_protein = sequence_structure_client.fold(
@@ -85,7 +85,7 @@ def inverse_fold(
)
assert isinstance(
esm3_client_inv_folded_protein, ESMProtein
), f"Using ESM3 client, ESMProtein was expected but got {protein}"
), f"Using ESM3 client, ESMProtein was expected but got {esm3_client_inv_folded_protein}"
# Inverse Folding with inverse folding client
sequence_structure_client_inv_folded_protein = (

View File

@@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -83,7 +83,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `sequence`\n",
"## `sequence`\n",
"The `sequence` track contains a sequence of 1-letter representations of the amino acids in the protein:"
]
},
@@ -100,7 +100,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `coordinates`\n",
"## `coordinates`\n",
"\n",
"\n",
"`coordinates` contains the 3D coordinates of atoms in the protein. It contains a tensor of shape `(n_residues, 37, 3)`, where \n",
@@ -139,7 +139,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -206,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -249,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@@ -438,7 +438,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `function_annotations`\n",
"## `function_annotations`\n",
"\n",
"An `ESMProtein` also contains function annotations derived from [InterPro](https://www.ebi.ac.uk/interpro/). Annotations directly from InterPro contain information about the following [entry types](https://interpro-documentation.readthedocs.io/en/latest/faq.html#what-are-entry-types):\n",
"* Family\n",
@@ -450,7 +450,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -475,7 +475,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -543,14 +543,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"When using our `ESM3` model, we recommend you use keyword annotations, which are keywords in the description of the InterPro entry and associated Gene Ontology terms from [InterPro2GO](https://www.ebi.ac.uk/GOA/InterPro2GO). For instance, for the InterPro entry [IPR011992](https://www.ebi.ac.uk/interpro/entry/InterPro/IPR011992/), the keywords are \"domain pair\", \"hand domain\", \"ef hand\", \"pair\", and \"ef\". For more details regarding how the keywords were computed, please refer to our preprint.\n",
"When using our `ESM3` model, we recommend you use keyword annotations, which are keywords in the description of the InterPro entry and associated Gene Ontology terms from [InterPro2GO](https://www.ebi.ac.uk/GOA/InterPro2GO). For instance, for the InterPro entry [IPR011992](https://www.ebi.ac.uk/interpro/entry/InterPro/IPR011992/), the keywords are \"domain pair\", \"hand domain\", \"ef hand\", \"pair\", and \"ef\". For more details regarding how the keywords were computed, please refer to our [preprint](https://www.biorxiv.org/content/10.1101/2024.07.01.600583v1.full.pdf).\n",
"\n",
"Practically, we can derive keyword annotations from the InterPro annotations with the function below. Each InterPro annotation corresponds to multiple keyword annotations covering the same range."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -608,7 +608,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `sasa`"
"## `sasa`"
]
},
{
@@ -620,7 +620,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@@ -707,7 +707,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@@ -772,7 +772,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [

File diff suppressed because one or more lines are too long

View File

@@ -48,6 +48,13 @@
"# Set up the client to Forge\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -450,7 +457,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in pink.\n"
"Now, we can visualize the generated structure using `py3Dmol`. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in red.\n"
]
},
{
@@ -478,7 +485,7 @@
")\n",
"view.addStyle(\n",
" {\"resi\": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},\n",
" {\"cartoon\": {\"color\": \"pink\"}},\n",
" {\"cartoon\": {\"color\": \"red\"}},\n",
" viewer=(0, 1),\n",
")\n",
"view.zoomTo()\n",

View File

@@ -1,2 +1,2 @@
__version__ = "3.1.3"
__version__ = "3.1.4"

View File

@@ -1,6 +1,7 @@
import os
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.utils.forge_context_manager import ForgeBatchExecutor
# Note: please do not import ESM3SageMakerClient here since that requires AWS SDK.
@@ -20,3 +21,16 @@ def client(
Default is wait indefinitely.
"""
return ESM3ForgeInferenceClient(model, url, token, request_timeout)
def batch_executor(max_attempts: int = 10):
    """Create a ForgeBatchExecutor for running many Forge calls concurrently.

    Args:
        max_attempts: Maximum number of attempts to make per task before
            giving up.

    Returns:
        A ForgeBatchExecutor, meant to be used as a context manager.

    Usage:
        with batch_executor() as executor:
            outputs = executor.execute_batch(user_func=fn, arg=list_of_inputs)

        Keyword arguments are forwarded to ``fn`` by name. List-valued
        arguments are split one element per call; other values are passed
        unchanged to every call.
    """
    return ForgeBatchExecutor(max_attempts)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC
from copy import deepcopy
from typing import List, Sequence
import attr
@@ -108,7 +109,9 @@ class ESMProtein(ProteinType):
secondary_structure=None,
sasa=None,
function_annotations=None,
coordinates=torch.tensor(protein_complex.atom37_positions),
coordinates=torch.tensor(
protein_complex.atom37_positions, dtype=torch.float32
),
)
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
@@ -164,6 +167,10 @@ class ESMProtein(ProteinType):
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
def copy(self) -> "ESMProtein":
    """Return an independent duplicate of this ESMProtein.

    The duplicate is produced with ``deepcopy``, so nested objects are
    copied as well rather than shared with the original.
    """
    duplicate = deepcopy(self)
    return duplicate
@define
class ESMProteinTensor(ProteinType):
@@ -244,6 +251,10 @@ class ESMProteinTensor(ProteinType):
).to(device),
)
def copy(self) -> ESMProteinTensor:
    """Return an independent duplicate of this ESMProteinTensor.

    The duplicate is produced with ``deepcopy``, so nested objects are
    copied as well rather than shared with the original.
    """
    duplicate = deepcopy(self)
    return duplicate
@define
class ESMProteinError(Exception, ProteinType):
@@ -338,6 +349,11 @@ class ForwardTrackData:
class LogitsConfig:
# Logits.
sequence: bool = False
# Note that getting logits for tracks other than sequence
# are not supported by Forge today, due to their impractical
# data sizes.
# These are of course supported when running local OSS models.
structure: bool = False
secondary_structure: bool = False
sasa: bool = False

View File

@@ -1,5 +1,6 @@
import base64
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
from functools import wraps
from typing import Sequence
from urllib.parse import urljoin
@@ -31,6 +32,8 @@ from esm.utils.misc import (
from esm.utils.sampling import validate_sampling_config
from esm.utils.types import FunctionAnnotation
skip_retries_var = ContextVar("skip_retries", default=False)
def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
if l is None or len(l) <= 0:
@@ -59,7 +62,7 @@ def log_retry_attempt(retry_state):
def _validate_protein_tensor_input(input):
if not isinstance(input, ESMProteinTensor):
raise ValueError(
"Input must be an ESMProteinTensor instance. "
f"Input must be an ESMProteinTensor instance, but received {type(input)} instead. "
"Use encode() API to encode an ESMProtein into ESMProteinTensor."
)
@@ -186,6 +189,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
@wraps(func)
def wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return func(instance, *args, **kwargs)
retry_decorator = retry(
retry=retry_if_result(retry_if_specific_error),
wait=wait_exponential(

View File

@@ -0,0 +1,138 @@
import threading
from collections import deque
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from contextvars import copy_context
from typing import Any, Callable, Dict, List
from tqdm import tqdm
from esm.sdk.api import ESMProteinError
from esm.sdk.forge import (
retry_if_specific_error,
skip_retries_var,
)
# Layout template passed to tqdm as ``bar_format`` by ForgeBatchExecutor.execute_batch:
# description, percentage, the bar itself, done/total counts, elapsed and
# remaining time, then the postfix text (the Success/Fail/Retry counters).
TQDM_BAR_FORMAT = (
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
"[Elapsed: {elapsed} | Remaining: {remaining}] {postfix}"
)
class AIMDRateLimiter:
    """Rate limiter with AIMD (Additive Increase/Multiplicative Decrease) control.

    The limiter tracks a single integer, ``concurrency``. Each call to
    ``adjust_concurrency`` either halves it (when an error was seen) or adds
    ``step_up`` to it, always keeping the value within
    ``[min_concurrency, max_concurrency]``.
    """

    def __init__(
        self,
        initial_concurrency: int = 32,
        min_concurrency: int = 1,
        max_concurrency: int = 512,
        step_up: int = 1,
    ):
        """Set up the limiter.

        Args:
            initial_concurrency: Starting concurrency; clamped into
                ``[min_concurrency, max_concurrency]``.
            min_concurrency: Lower bound that decreases never go below.
            max_concurrency: Upper bound that increases never go above.
            step_up: Amount added to the concurrency on each error-free
                adjustment.

        Raises:
            ValueError: If ``min_concurrency`` is greater than
                ``max_concurrency``.
        """
        if min_concurrency > max_concurrency:
            raise ValueError(
                f"min_concurrency ({min_concurrency}) must not exceed "
                f"max_concurrency ({max_concurrency})"
            )
        self.min_concurrency = min_concurrency
        self.max_concurrency = max_concurrency
        self.step_up = step_up
        # Start inside the configured bounds; otherwise the limiter could begin
        # above max_concurrency (or below min_concurrency) until the first
        # adjustment pulls it back.
        self.concurrency = max(
            min_concurrency, min(max_concurrency, initial_concurrency)
        )
        self._lock = threading.Lock()

    def adjust_concurrency(self, error_seen: bool) -> int:
        """Update concurrency based on if an error is seen.

        Args:
            error_seen: True to halve the concurrency (multiplicative
                decrease), False to add ``step_up`` (additive increase).

        Returns:
            The concurrency value after the adjustment.
        """
        # The lock keeps the read-modify-write atomic when several threads
        # adjust the limiter at once.
        with self._lock:
            if error_seen:
                self.concurrency = max(self.min_concurrency, self.concurrency // 2)
            else:
                self.concurrency = min(
                    self.max_concurrency, self.concurrency + self.step_up
                )
            return self.concurrency
class ForgeBatchExecutor:
    """Context manager for managing concurrent calls with rate limiting.

    Tasks run on a thread pool. The number of tasks in flight is capped by an
    AIMDRateLimiter that is halved when a task fails with an ESMProteinError
    carrying error code 429 and is increased otherwise. While the context is
    active, ``skip_retries_var`` is set to True.
    """

    def __init__(self, max_attempts: int = 10):
        """Create the rate limiter and the worker pool.

        Args:
            max_attempts: Maximum number of attempts per task.
        """
        self.rate_limiter = AIMDRateLimiter()
        self.max_attempts = max_attempts
        # Size the pool for the highest concurrency the limiter can reach.
        pool_size = self.rate_limiter.max_concurrency
        self._executor = ThreadPoolExecutor(max_workers=pool_size)
        self._skip_retries_token = None

    def __enter__(self):
        """Set ``skip_retries_var`` to True and return the executor."""
        self._skip_retries_token = skip_retries_var.set(True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore ``skip_retries_var`` and shut the worker pool down."""
        token = self._skip_retries_token
        if token is not None:
            skip_retries_var.reset(token)
        if self._executor:
            self._executor.shutdown(wait=True)

    def _validate_inputs(self, inputs: Dict[str, Any]) -> int:
        """Validate input lengths and return the number of tasks.

        Only values that are ``list`` instances count as batched inputs. With
        no list-valued input there is exactly one task.

        Raises:
            ValueError: If the list-valued inputs differ in length.
        """
        list_sizes = {
            len(value) for value in inputs.values() if isinstance(value, list)
        }
        if not list_sizes:
            return 1
        if len(list_sizes) > 1:
            raise ValueError("All list-valued arguments must have the same length")
        return list_sizes.pop()

    def execute_batch(self, user_func: Callable, **kwargs: Any) -> List[Any]:
        """Call the endpoint with batched inputs, managing concurrency and retries.

        Args:
            user_func: Callable invoked once per task with keyword arguments.
            **kwargs: Arguments for ``user_func``. List values are indexed per
                task; any other value is passed unchanged to every call.

        Returns:
            A list with one entry per task, in input order. An entry is the
            task's result, or the exception that ended it when the task
            failed and was not retried.
        """
        num_tasks = self._validate_inputs(kwargs)

        # Queue entries are (task_index, attempt) pairs; attempts start at 1.
        pending = deque((task_idx, 1) for task_idx in range(num_tasks))
        results = [None] * num_tasks
        in_flight = {}
        succeeded = 0
        failed = 0
        retried = 0

        with tqdm(
            total=num_tasks, desc="Processing", bar_format=TQDM_BAR_FORMAT, unit="task"
        ) as pbar:
            while pending or in_flight:
                # Top up the pool to the limiter's current concurrency.
                limit = self.rate_limiter.concurrency
                while pending and len(in_flight) < limit:
                    task_idx, attempt = pending.popleft()
                    call_kwargs = {}
                    for name, value in kwargs.items():
                        if isinstance(value, list):
                            call_kwargs[name] = value[task_idx]
                        else:
                            call_kwargs[name] = value
                    # Run in a copy of the current context so the worker thread
                    # sees the caller's context variables.
                    ctx = copy_context()
                    future = self._executor.submit(ctx.run, user_func, **call_kwargs)
                    in_flight[future] = (task_idx, attempt)

                finished, _ = wait(
                    in_flight.keys(), return_when=FIRST_COMPLETED, timeout=1
                )

                rate_limited = False
                for future in finished:
                    task_idx, attempt = in_flight.pop(future)
                    try:
                        outcome = future.result()
                        # A returned error object is handled like a raised one.
                        if isinstance(outcome, ESMProteinError):
                            raise outcome
                        results[task_idx] = outcome
                        succeeded += 1
                        pbar.update(1)
                    except Exception as exc:
                        if (
                            not retry_if_specific_error(exc)
                            or attempt >= self.max_attempts
                        ):
                            # Out of attempts or not retryable: keep the error.
                            results[task_idx] = exc  # type: ignore
                            failed += 1
                            pbar.update(0)
                        else:
                            pending.append((task_idx, attempt + 1))
                            # Only scale concurrency if hit rate limit errors.
                            if isinstance(exc, ESMProteinError) and exc.error_code == 429:
                                rate_limited = True
                            retried += 1
                            pbar.update(0)

                self.rate_limiter.adjust_concurrency(rate_limited)
                pbar.set_postfix_str(
                    f"Success={succeeded} Fail={failed} Retry={retried}"
                )

        return results

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.1.3"
version = "3.1.4"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"

View File

@@ -8,7 +8,7 @@
"source": [
"# Generation UI\n",
"\n",
"This is the most flexible notebook for generating protein sequences using the ESM3 model."
"This is the most flexible notebook for generating protein sequences using the ESM3 model.\n"
]
},
{
@@ -65,15 +65,12 @@
"# @title Create Generation UI\n",
"# @markdown If running on Google colab, it is recommended to use the light theme and select the \"View output fullscreen\" option in the cell toolbar for the best experience\n",
"\n",
"from functools import partial\n",
"\n",
"from esm.widgets.utils.clients import get_forge_client\n",
"from esm.widgets.utils.types import ClientInitContainer\n",
"from esm.widgets.views.generation import create_generation_ui\n",
"\n",
"client_container = ClientInitContainer()\n",
"client_container.client_init_callback = partial(get_forge_client, model_name)\n",
"create_generation_ui(client_container)"
"create_generation_ui(get_forge_client(model_name))"
]
}
],