This commit is contained in:
Jun Gong
2024-08-12 11:09:15 -07:00
committed by GitHub
parent aaabedcf58
commit 0774600af0
8 changed files with 155 additions and 56 deletions

View File

@@ -1 +1 @@
__version__ = "3.0.2"
__version__ = "3.0.3"

View File

@@ -2,6 +2,8 @@ import os
from esm.sdk.forge import ESM3ForgeInferenceClient
# Note: please do not import ESM3SageMakerClient here since that requires AWS SDK.
def client(
model="esm3-sm-open-v1",

View File

@@ -32,7 +32,7 @@ class ESMProtein(ProteinType):
# Tracks
sequence: str | None = None
secondary_structure: str | None = None
sasa: list[int | float | None] | None = None
sasa: list[float | None] | None = None
function_annotations: list[FunctionAnnotation] | None = None
coordinates: torch.Tensor | None = None

View File

@@ -1,5 +1,6 @@
import asyncio
from typing import Sequence
from urllib.parse import urljoin
import requests
import torch
@@ -117,7 +118,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
"condition_on_coordinates_only": config.condition_on_coordinates_only,
}
try:
data = self.__post("generate", request, input.potential_sequence_of_concern)
data = self._post("generate", request, input.potential_sequence_of_concern)
except RuntimeError as e:
return ESMProteinError(error_msg=str(e))
@@ -162,7 +163,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
}
try:
data = self.__post(
data = self._post(
"generate_tensor", request, input.potential_sequence_of_concern
)
except RuntimeError as e:
@@ -233,7 +234,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
"embedding_config": embedding_config,
}
try:
data = self.__post(
data = self._post(
"forward_and_sample", request, input.potential_sequence_of_concern
)
except RuntimeError as e:
@@ -297,7 +298,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
request = {"inputs": tracks, "model": self.model}
try:
data = self.__post("encode", request, input.potential_sequence_of_concern)
data = self._post("encode", request, input.potential_sequence_of_concern)
except RuntimeError as e:
return ESMProteinError(error_msg=str(e))
@@ -332,7 +333,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
}
try:
data = self.__post("decode", request, input.potential_sequence_of_concern)
data = self._post("decode", request, input.potential_sequence_of_concern)
except RuntimeError as e:
return ESMProteinError(error_msg=str(e))
@@ -380,7 +381,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
"inputs": req,
"logits_config": logits_config,
}
data = self.__post("logits", request, input.potential_sequence_of_concern)
data = self._post("logits", request, input.potential_sequence_of_concern)
def _maybe_logits(track: str):
if "logits" in data and track in data["logits"]:
@@ -401,11 +402,11 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
return output
def __post(self, endpoint, request, potential_sequence_of_concern):
def _post(self, endpoint, request, potential_sequence_of_concern):
request["potential_sequence_of_concern"] = potential_sequence_of_concern
response = requests.post(
f"{self.url}/api/v1/{endpoint}",
urljoin(self.url, f"/api/v1/{endpoint}"),
json=request,
headers=self.headers,
timeout=self.request_timeout,

56
esm/sdk/sagemaker.py Normal file
View File

@@ -0,0 +1,56 @@
import json
import boto3
from esm.sdk.forge import ESM3ForgeInferenceClient
class ESM3SageMakerClient(ESM3ForgeInferenceClient):
def __init__(self, endpoint_name: str, model: str):
"""ESM3 client that talks to a SageMaker endpoint.
Args:
endpoint_name: Name of the SageMaker endpoint.
model: Name of the ESM3 model.
"""
# Dummy URL and token to make ESM3ForgeInferenceClient happy.
super().__init__(model=model, url="", token="dummy")
self._endpoint_name = endpoint_name
self._model = model
self._client = boto3.client(service_name="sagemaker-runtime")
def _post(self, endpoint, request, potential_sequence_of_concern):
request["potential_sequence_of_concern"] = potential_sequence_of_concern
invocations_request = {
# Duplicate these fields at the top level to make Forge requests consistent.
"model": request["model"],
"request_id": "", # Forge specific field.
"user_id": "", # Forge specific field.
# Invocation data bits.
"api_ver": "v1", # Must be v1 right now.
"endpoint": endpoint,
# Wrapped request.
endpoint: request,
}
try:
response = self._client.invoke_endpoint(
EndpointName=self._endpoint_name,
ContentType="application/json",
Body=json.dumps(invocations_request),
)
except Exception as e:
raise RuntimeError(f"Failure in {endpoint}: {e}")
data = json.loads(response["Body"].read().decode())
# Response must match request.
assert data["endpoint"] == endpoint
# Get the actual responses under the endpoint key.
data = data[endpoint]
return data

View File

@@ -200,7 +200,18 @@ def decode_sasa(
if sasa_tokens[-1] != 0:
raise ValueError("SASA does not end with 0 corresponding to EOS token")
sasa_tokens = sasa_tokens[1:-1]
sasa = sasa_tokens.tolist()
if sasa_tokens.dtype in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.long,
]:
# Decode if int
sasa = sasa_tokenizer.decode_float(sasa_tokens)
else:
# If already float, just convert to list
sasa = sasa_tokens.tolist()
return sasa

View File

@@ -31,6 +31,7 @@ from esm.utils.misc import stack_variable_length_tensors
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
from esm.utils.sampling import (
_BatchedESMProteinTensor,
get_sampling_mask,
sample_function_logits,
sample_logits,
sample_residue_annotation_logits,
@@ -108,14 +109,7 @@ def iterative_sampling_raw(
raw_proteins: list[ESMProtein | ESMProteinError] = []
for output_tokens in output_tokens_list:
if isinstance(output_tokens, ESMProteinTensor):
try:
raw_proteins.append(client.decode(output_tokens))
except Exception:
# Print the input tokens so we know what is wrong.
print("Encountered exception during decoding:")
print(output_tokens)
# Re-raise.
raise
raw_proteins.append(client.decode(output_tokens))
elif isinstance(output_tokens, ESMProteinError):
raw_proteins.append(output_tokens)
else:
@@ -214,9 +208,15 @@ def _get_masked_positions(
track: str, tokens: torch.Tensor, mask_token_id: int
) -> torch.Tensor:
if track == "function":
return torch.all(tokens == mask_token_id, dim=-1).to(tokens.device)
mask = torch.all(tokens == mask_token_id, dim=-1).to(tokens.device)
else:
return tokens == mask_token_id
mask = tokens == mask_token_id
# Should not sample BOS and EOS positions.
mask[..., 0] = False
mask[..., -1] = False
return mask
def _get_iterative_sampling_mask_for_prompt_and_step(
@@ -407,6 +407,7 @@ def iterative_sampling_tokens(
per_prompt_forward_out,
sampling_config,
tokenizers,
decode_sasa_tokens=False,
)
# All positions sampled after _sample_per_prompt() above.
@@ -493,14 +494,21 @@ def _sample_per_prompt(
logits_output: LogitsOutput,
sampling_config: SamplingConfig,
tokenizers: TokenizerCollectionProtocol,
decode_sasa_tokens: bool = True,
) -> ForwardAndSampleOutput:
assert logits_output.logits is not None
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
return x.clone() if x is not None else None
# Sampling
tokens_dir = {}
track_sampling_metadata_dir: dict[str, dict | None] = {}
for track in ["sequence", "structure", "secondary_structure"]:
integer_sampling_tracks = ["sequence", "structure", "secondary_structure"]
if not decode_sasa_tokens:
integer_sampling_tracks.append("sasa")
for track in integer_sampling_tracks:
config = getattr(sampling_config, track)
if config is None:
tokens_dir[track] = maybe_clone(getattr(protein, track))
@@ -514,21 +522,33 @@ def _sample_per_prompt(
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
track_sampling_metadata_dir[track] = sampling_metadata
# Sample SASA seperately
config = getattr(sampling_config, "sasa")
track_sampling_metadata_dir["sasa"] = None
# Sample SASA seperately (if needed)
if decode_sasa_tokens:
config = getattr(sampling_config, "sasa")
track_sampling_metadata_dir["sasa"] = None
if config is not None:
if config.topk_logprobs > 0:
warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
sasa_logits = logits_output.logits.sasa[0, ...] # type: ignore
sasa_value = sample_sasa_logits(sasa_logits, protein.sasa[0, ...]) # type: ignore
tokens_dir["sasa"] = sasa_value
if config is None:
tokens_dir["sasa"] = maybe_clone(getattr(protein, "sasa"))
else:
if config.topk_logprobs > 0:
warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
probs = sasa_logits.softmax(dim=-1)
entropy = -(probs * sasa_logits.log_softmax(-1)).sum(-1)
assert logits_output.logits.sasa is not None
assert protein.sasa is not None
track_sampling_metadata_dir["sasa"] = {"entropy": entropy}
sasa_logits = logits_output.logits.sasa
sasa_value = sample_sasa_logits(
sasa_logits,
protein.sasa,
sampling_track_config=config,
mask_idx=tokenizers.sasa.mask_token_id,
)
tokens_dir["sasa"] = sasa_value
probs = sasa_logits.softmax(dim=-1)
entropy = -(probs * sasa_logits.log_softmax(-1)).sum(-1)
track_sampling_metadata_dir["sasa"] = {"entropy": entropy}
# Sample function and residue annotations separately
config = getattr(sampling_config, "function")
@@ -612,25 +632,7 @@ def _sample_track(
logits, temperature=temperature, top_p=sampling_track_config.top_p
)
log_probs = logits.log_softmax(-1)
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, )
sampling_mask[:, 0] = False
sampling_mask[:, -1] = False
# Do not sample at special token positions but allow sampling at mask token
special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
if len(special_minus_mask) > 0:
special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
assert special_tokens.numel() > 0
sampling_mask = sampling_mask & (
tokens[..., None] != special_tokens[None, :]
).all(-1)
# Keep only samples from masked positions (if specified)
if sampling_track_config.only_sample_masked_tokens:
masked_tokens = tokens == mask_idx
sampling_mask = sampling_mask & masked_tokens
sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)
return _compute_track_metadata(

View File

@@ -248,6 +248,8 @@ def sample_residue_annotation_logits(
def sample_sasa_logits(
logits: torch.Tensor,
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
) -> torch.Tensor:
sasa_probs = torch.nn.functional.softmax(logits, dim=-1)
max_prob_idx = torch.argmax(sasa_probs, dim=-1)
@@ -255,12 +257,11 @@ def sample_sasa_logits(
sasa_bins = (sasa_bins[:-1] + sasa_bins[1:]) / 2
sasa_bins = sasa_bins.to(sasa_probs.device)
sampling_mask = get_sampling_mask(tokens, sampling_track_config, mask_idx)
# Adjust sasa_values based on max_prob_idx conditions
sasa_value = torch.sum(sasa_probs[..., 3:-1] * sasa_bins, dim=-1)
sasa_value[tokens == 0] = float("-inf")
sasa_value[tokens == 1] = float("-inf")
sasa_value[tokens == 2] = float("-inf")
sasa_value[max_prob_idx == 18] = float("inf")
sasa_value[~sampling_mask] = float("inf")
return sasa_value
@@ -294,3 +295,29 @@ def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
if isinstance(value, (float, int)):
value = torch.full_like(logits[..., 0], value, dtype=logits.dtype)
return value.to(logits.device).expand_as(logits[..., 0]).reshape(-1)
def get_sampling_mask(
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
):
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, )
sampling_mask[:, 0] = False
sampling_mask[:, -1] = False
# Do not sample at special token positions but allow sampling at mask token
special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
if len(special_minus_mask) > 0:
special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
assert special_tokens.numel() > 0
sampling_mask = sampling_mask & (
tokens[..., None] != special_tokens[None, :]
).all(-1)
# Keep only samples from masked positions (if specified)
if sampling_track_config.only_sample_masked_tokens:
masked_tokens = tokens == mask_idx
sampling_mask = sampling_mask & masked_tokens
return sampling_mask