mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
3.0.3 (#80)
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "3.0.2"
|
||||
__version__ = "3.0.3"
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
|
||||
from esm.sdk.forge import ESM3ForgeInferenceClient
|
||||
|
||||
# Note: please do not import ESM3SageMakerClient here since that requires AWS SDK.
|
||||
|
||||
|
||||
def client(
|
||||
model="esm3-sm-open-v1",
|
||||
|
||||
@@ -32,7 +32,7 @@ class ESMProtein(ProteinType):
|
||||
# Tracks
|
||||
sequence: str | None = None
|
||||
secondary_structure: str | None = None
|
||||
sasa: list[int | float | None] | None = None
|
||||
sasa: list[float | None] | None = None
|
||||
function_annotations: list[FunctionAnnotation] | None = None
|
||||
coordinates: torch.Tensor | None = None
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Sequence
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -117,7 +118,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
"condition_on_coordinates_only": config.condition_on_coordinates_only,
|
||||
}
|
||||
try:
|
||||
data = self.__post("generate", request, input.potential_sequence_of_concern)
|
||||
data = self._post("generate", request, input.potential_sequence_of_concern)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
|
||||
@@ -162,7 +163,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
}
|
||||
|
||||
try:
|
||||
data = self.__post(
|
||||
data = self._post(
|
||||
"generate_tensor", request, input.potential_sequence_of_concern
|
||||
)
|
||||
except RuntimeError as e:
|
||||
@@ -233,7 +234,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
"embedding_config": embedding_config,
|
||||
}
|
||||
try:
|
||||
data = self.__post(
|
||||
data = self._post(
|
||||
"forward_and_sample", request, input.potential_sequence_of_concern
|
||||
)
|
||||
except RuntimeError as e:
|
||||
@@ -297,7 +298,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
request = {"inputs": tracks, "model": self.model}
|
||||
|
||||
try:
|
||||
data = self.__post("encode", request, input.potential_sequence_of_concern)
|
||||
data = self._post("encode", request, input.potential_sequence_of_concern)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
|
||||
@@ -332,7 +333,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
}
|
||||
|
||||
try:
|
||||
data = self.__post("decode", request, input.potential_sequence_of_concern)
|
||||
data = self._post("decode", request, input.potential_sequence_of_concern)
|
||||
except RuntimeError as e:
|
||||
return ESMProteinError(error_msg=str(e))
|
||||
|
||||
@@ -380,7 +381,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
"inputs": req,
|
||||
"logits_config": logits_config,
|
||||
}
|
||||
data = self.__post("logits", request, input.potential_sequence_of_concern)
|
||||
data = self._post("logits", request, input.potential_sequence_of_concern)
|
||||
|
||||
def _maybe_logits(track: str):
|
||||
if "logits" in data and track in data["logits"]:
|
||||
@@ -401,11 +402,11 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
|
||||
return output
|
||||
|
||||
def __post(self, endpoint, request, potential_sequence_of_concern):
|
||||
def _post(self, endpoint, request, potential_sequence_of_concern):
|
||||
request["potential_sequence_of_concern"] = potential_sequence_of_concern
|
||||
|
||||
response = requests.post(
|
||||
f"{self.url}/api/v1/{endpoint}",
|
||||
urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
json=request,
|
||||
headers=self.headers,
|
||||
timeout=self.request_timeout,
|
||||
|
||||
56
esm/sdk/sagemaker.py
Normal file
56
esm/sdk/sagemaker.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
|
||||
import boto3
|
||||
|
||||
from esm.sdk.forge import ESM3ForgeInferenceClient
|
||||
|
||||
|
||||
class ESM3SageMakerClient(ESM3ForgeInferenceClient):
    """Forge-compatible ESM3 client backed by a SageMaker endpoint.

    Only ``_post`` is overridden: instead of issuing an HTTP request, the
    payload is wrapped in an invocation envelope and sent through the
    ``sagemaker-runtime`` ``invoke_endpoint`` call.
    """

    def __init__(self, endpoint_name: str, model: str):
        """ESM3 client that talks to a SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
            model: Name of the ESM3 model.
        """
        # Dummy URL and token to make ESM3ForgeInferenceClient happy.
        super().__init__(model=model, url="", token="dummy")

        self._endpoint_name = endpoint_name
        self._model = model

        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """Send ``request`` to the SageMaker endpoint and return its payload.

        Args:
            endpoint: Name of the logical API endpoint (e.g. ``"generate"``);
                also used as the key that wraps the request and the response.
            request: Request dictionary; it is mutated in place to carry
                ``potential_sequence_of_concern`` and must contain ``"model"``.
            potential_sequence_of_concern: Flag forwarded inside the request.

        Returns:
            The value stored under the ``endpoint`` key of the decoded
            JSON response.

        Raises:
            RuntimeError: If the invocation fails or the response does not
                echo back the requested endpoint.
        """
        request["potential_sequence_of_concern"] = potential_sequence_of_concern

        invocations_request = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": request["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: request,
        }

        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(invocations_request),
            )
        except Exception as e:
            # Keep the original exception attached as the cause for debugging.
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e

        data = json.loads(response["Body"].read().decode())

        # Response must match request. Raise explicitly rather than assert so
        # the check survives `python -O` and surfaces as a RuntimeError.
        if not isinstance(data, dict) or data.get("endpoint") != endpoint:
            raise RuntimeError(
                f"Failure in {endpoint}: response does not match the requested endpoint"
            )

        # Get the actual responses under the endpoint key.
        return data[endpoint]
|
||||
@@ -200,7 +200,18 @@ def decode_sasa(
|
||||
if sasa_tokens[-1] != 0:
|
||||
raise ValueError("SASA does not end with 0 corresponding to EOS token")
|
||||
sasa_tokens = sasa_tokens[1:-1]
|
||||
sasa = sasa_tokens.tolist()
|
||||
if sasa_tokens.dtype in [
|
||||
torch.int8,
|
||||
torch.int16,
|
||||
torch.int32,
|
||||
torch.int64,
|
||||
torch.long,
|
||||
]:
|
||||
# Decode if int
|
||||
sasa = sasa_tokenizer.decode_float(sasa_tokens)
|
||||
else:
|
||||
# If already float, just convert to list
|
||||
sasa = sasa_tokens.tolist()
|
||||
return sasa
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from esm.utils.misc import stack_variable_length_tensors
|
||||
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
||||
from esm.utils.sampling import (
|
||||
_BatchedESMProteinTensor,
|
||||
get_sampling_mask,
|
||||
sample_function_logits,
|
||||
sample_logits,
|
||||
sample_residue_annotation_logits,
|
||||
@@ -108,14 +109,7 @@ def iterative_sampling_raw(
|
||||
raw_proteins: list[ESMProtein | ESMProteinError] = []
|
||||
for output_tokens in output_tokens_list:
|
||||
if isinstance(output_tokens, ESMProteinTensor):
|
||||
try:
|
||||
raw_proteins.append(client.decode(output_tokens))
|
||||
except Exception:
|
||||
# Print the input tokens so we know what is wrong.
|
||||
print("Encountered exception during decoding:")
|
||||
print(output_tokens)
|
||||
# Re-raise.
|
||||
raise
|
||||
raw_proteins.append(client.decode(output_tokens))
|
||||
elif isinstance(output_tokens, ESMProteinError):
|
||||
raw_proteins.append(output_tokens)
|
||||
else:
|
||||
@@ -214,9 +208,15 @@ def _get_masked_positions(
|
||||
track: str, tokens: torch.Tensor, mask_token_id: int
|
||||
) -> torch.Tensor:
|
||||
if track == "function":
|
||||
return torch.all(tokens == mask_token_id, dim=-1).to(tokens.device)
|
||||
mask = torch.all(tokens == mask_token_id, dim=-1).to(tokens.device)
|
||||
else:
|
||||
return tokens == mask_token_id
|
||||
mask = tokens == mask_token_id
|
||||
|
||||
# Should not sample BOS and EOS positions.
|
||||
mask[..., 0] = False
|
||||
mask[..., -1] = False
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def _get_iterative_sampling_mask_for_prompt_and_step(
|
||||
@@ -407,6 +407,7 @@ def iterative_sampling_tokens(
|
||||
per_prompt_forward_out,
|
||||
sampling_config,
|
||||
tokenizers,
|
||||
decode_sasa_tokens=False,
|
||||
)
|
||||
|
||||
# All positions sampled after _sample_per_prompt() above.
|
||||
@@ -493,14 +494,21 @@ def _sample_per_prompt(
|
||||
logits_output: LogitsOutput,
|
||||
sampling_config: SamplingConfig,
|
||||
tokenizers: TokenizerCollectionProtocol,
|
||||
decode_sasa_tokens: bool = True,
|
||||
) -> ForwardAndSampleOutput:
|
||||
assert logits_output.logits is not None
|
||||
|
||||
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
|
||||
return x.clone() if x is not None else None
|
||||
|
||||
# Sampling
|
||||
tokens_dir = {}
|
||||
track_sampling_metadata_dir: dict[str, dict | None] = {}
|
||||
for track in ["sequence", "structure", "secondary_structure"]:
|
||||
integer_sampling_tracks = ["sequence", "structure", "secondary_structure"]
|
||||
if not decode_sasa_tokens:
|
||||
integer_sampling_tracks.append("sasa")
|
||||
|
||||
for track in integer_sampling_tracks:
|
||||
config = getattr(sampling_config, track)
|
||||
if config is None:
|
||||
tokens_dir[track] = maybe_clone(getattr(protein, track))
|
||||
@@ -514,21 +522,33 @@ def _sample_per_prompt(
|
||||
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
|
||||
track_sampling_metadata_dir[track] = sampling_metadata
|
||||
|
||||
# Sample SASA seperately
|
||||
config = getattr(sampling_config, "sasa")
|
||||
track_sampling_metadata_dir["sasa"] = None
|
||||
# Sample SASA seperately (if needed)
|
||||
if decode_sasa_tokens:
|
||||
config = getattr(sampling_config, "sasa")
|
||||
track_sampling_metadata_dir["sasa"] = None
|
||||
|
||||
if config is not None:
|
||||
if config.topk_logprobs > 0:
|
||||
warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
|
||||
sasa_logits = logits_output.logits.sasa[0, ...] # type: ignore
|
||||
sasa_value = sample_sasa_logits(sasa_logits, protein.sasa[0, ...]) # type: ignore
|
||||
tokens_dir["sasa"] = sasa_value
|
||||
if config is None:
|
||||
tokens_dir["sasa"] = maybe_clone(getattr(protein, "sasa"))
|
||||
else:
|
||||
if config.topk_logprobs > 0:
|
||||
warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
|
||||
|
||||
probs = sasa_logits.softmax(dim=-1)
|
||||
entropy = -(probs * sasa_logits.log_softmax(-1)).sum(-1)
|
||||
assert logits_output.logits.sasa is not None
|
||||
assert protein.sasa is not None
|
||||
|
||||
track_sampling_metadata_dir["sasa"] = {"entropy": entropy}
|
||||
sasa_logits = logits_output.logits.sasa
|
||||
sasa_value = sample_sasa_logits(
|
||||
sasa_logits,
|
||||
protein.sasa,
|
||||
sampling_track_config=config,
|
||||
mask_idx=tokenizers.sasa.mask_token_id,
|
||||
)
|
||||
tokens_dir["sasa"] = sasa_value
|
||||
|
||||
probs = sasa_logits.softmax(dim=-1)
|
||||
entropy = -(probs * sasa_logits.log_softmax(-1)).sum(-1)
|
||||
|
||||
track_sampling_metadata_dir["sasa"] = {"entropy": entropy}
|
||||
|
||||
# Sample function and residue annotations separately
|
||||
config = getattr(sampling_config, "function")
|
||||
@@ -612,25 +632,7 @@ def _sample_track(
|
||||
logits, temperature=temperature, top_p=sampling_track_config.top_p
|
||||
)
|
||||
log_probs = logits.log_softmax(-1)
|
||||
|
||||
# Do not sample at BOS and EOS tokens
|
||||
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, )
|
||||
sampling_mask[:, 0] = False
|
||||
sampling_mask[:, -1] = False
|
||||
|
||||
# Do not sample at special token positions but allow sampling at mask token
|
||||
special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
|
||||
if len(special_minus_mask) > 0:
|
||||
special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
|
||||
assert special_tokens.numel() > 0
|
||||
sampling_mask = sampling_mask & (
|
||||
tokens[..., None] != special_tokens[None, :]
|
||||
).all(-1)
|
||||
|
||||
# Keep only samples from masked positions (if specified)
|
||||
if sampling_track_config.only_sample_masked_tokens:
|
||||
masked_tokens = tokens == mask_idx
|
||||
sampling_mask = sampling_mask & masked_tokens
|
||||
sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
|
||||
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)
|
||||
|
||||
return _compute_track_metadata(
|
||||
|
||||
@@ -248,6 +248,8 @@ def sample_residue_annotation_logits(
|
||||
def sample_sasa_logits(
|
||||
logits: torch.Tensor,
|
||||
tokens: torch.Tensor,
|
||||
sampling_track_config: SamplingTrackConfig,
|
||||
mask_idx: int,
|
||||
) -> torch.Tensor:
|
||||
sasa_probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
max_prob_idx = torch.argmax(sasa_probs, dim=-1)
|
||||
@@ -255,12 +257,11 @@ def sample_sasa_logits(
|
||||
sasa_bins = (sasa_bins[:-1] + sasa_bins[1:]) / 2
|
||||
sasa_bins = sasa_bins.to(sasa_probs.device)
|
||||
|
||||
sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
|
||||
# Adjust sasa_values based on max_prob_idx conditions
|
||||
sasa_value = torch.sum(sasa_probs[..., 3:-1] * sasa_bins, dim=-1)
|
||||
sasa_value[tokens == 0] = float("-inf")
|
||||
sasa_value[tokens == 1] = float("-inf")
|
||||
sasa_value[tokens == 2] = float("-inf")
|
||||
sasa_value[max_prob_idx == 18] = float("inf")
|
||||
sasa_value[~sampling_mask] = float("inf")
|
||||
|
||||
return sasa_value
|
||||
|
||||
@@ -294,3 +295,29 @@ def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
|
||||
if isinstance(value, (float, int)):
|
||||
value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
|
||||
return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)
|
||||
|
||||
|
||||
def get_sampling_mask(
    tokens: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
    mask_idx: int,
) -> torch.Tensor:
    """Return a boolean mask of the positions that may be sampled.

    A position is excluded when it is the first or last element along the
    final axis, when its token is one of ``sampling_track_config.invalid_ids``
    (other than ``mask_idx``), or, if
    ``sampling_track_config.only_sample_masked_tokens`` is set, when its
    token is not ``mask_idx``.

    Args:
        tokens: Token tensor whose last axis is the sequence axis; both
            batched ``(B, L)`` and unbatched ``(L,)`` inputs are accepted.
        sampling_track_config: Supplies ``invalid_ids`` and
            ``only_sample_masked_tokens``.
        mask_idx: Id of the mask token, which always stays sampleable.

    Returns:
        Boolean tensor with the same shape as ``tokens``; ``True`` marks
        positions that may be sampled.
    """
    # Do not sample at BOS and EOS tokens. Indexing with an ellipsis keeps
    # this valid for any number of leading (batch) dimensions.
    sampling_mask = torch.ones_like(tokens, dtype=torch.bool)
    sampling_mask[..., 0] = False
    sampling_mask[..., -1] = False

    # Do not sample at special token positions but allow sampling at mask token
    special_minus_mask = sorted(set(sampling_track_config.invalid_ids) - {mask_idx})
    if special_minus_mask:
        special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
        # Compare every token against every special id; a position survives
        # only if it differs from all of them.
        sampling_mask = sampling_mask & (
            tokens[..., None] != special_tokens[None, :]
        ).all(-1)

    # Keep only samples from masked positions (if specified)
    if sampling_track_config.only_sample_masked_tokens:
        sampling_mask = sampling_mask & (tokens == mask_idx)

    return sampling_mask
|
||||
|
||||
Reference in New Issue
Block a user