mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Sync from internal
This commit is contained in:
@@ -1 +1,2 @@
|
||||
__version__ = "3.0.7post1"
|
||||
__version__ = "3.0.8"
|
||||
|
||||
|
||||
@@ -10,11 +10,7 @@ from esm.layers.rotary import RotaryEmbedding
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_heads: int,
|
||||
bias: bool = False,
|
||||
qk_layernorm: bool = True,
|
||||
self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -78,11 +78,7 @@ class GeometricReasoningOriginalImpl(nn.Module):
|
||||
affine.rot[..., None]
|
||||
.apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
|
||||
.split(
|
||||
[
|
||||
self.v_heads,
|
||||
self.v_heads,
|
||||
self.v_heads * self.num_vector_messages,
|
||||
],
|
||||
[self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages],
|
||||
dim=-2,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,9 +2,7 @@ import torch.nn as nn
|
||||
|
||||
|
||||
def RegressionHead(
|
||||
d_model: int,
|
||||
output_dim: int,
|
||||
hidden_dim: int | None = None,
|
||||
d_model: int, output_dim: int, hidden_dim: int | None = None
|
||||
) -> nn.Module:
|
||||
"""Single-hidden layer MLP for supervised output.
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from esm.utils.constants.physics import (
|
||||
BB_COORDINATES,
|
||||
)
|
||||
from esm.utils.constants.physics import BB_COORDINATES
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
RotationMatrix,
|
||||
|
||||
@@ -29,9 +29,7 @@ from esm.sdk.api import (
|
||||
ProteinType,
|
||||
SamplingConfig,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
)
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import (
|
||||
@@ -173,11 +171,7 @@ class OutputHeads(nn.Module):
|
||||
secondary_structure_logits = self.ss8_head(x)
|
||||
sasa_logits = self.sasa_head(x)
|
||||
function_logits = self.function_head(x)
|
||||
function_logits = einops.rearrange(
|
||||
function_logits,
|
||||
"... (k v) -> ... k v",
|
||||
k=8,
|
||||
)
|
||||
function_logits = einops.rearrange(function_logits, "... (k v) -> ... k v", k=8)
|
||||
|
||||
residue_logits = self.residue_head(x)
|
||||
|
||||
@@ -217,11 +211,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
super().__init__()
|
||||
self.encoder = EncodeInputs(d_model)
|
||||
self.transformer = TransformerStack(
|
||||
d_model,
|
||||
n_heads,
|
||||
v_heads,
|
||||
n_layers,
|
||||
mask_and_zero_frameless=True,
|
||||
d_model, n_heads, v_heads, n_layers, mask_and_zero_frameless=True
|
||||
)
|
||||
self.output_heads = OutputHeads(d_model)
|
||||
|
||||
@@ -237,9 +227,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_name: str = ESM3_OPEN_SMALL,
|
||||
device: torch.device | None = None,
|
||||
cls, model_name: str = ESM3_OPEN_SMALL, device: torch.device | None = None
|
||||
) -> ESM3:
|
||||
from esm.pretrained import load_local_model
|
||||
|
||||
@@ -489,15 +477,14 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
reference_sequence = encoding.get_default_sequence(sequence_length - 2)
|
||||
else:
|
||||
reference_sequence = input.sequence
|
||||
(
|
||||
function_tokens,
|
||||
residue_annotation_tokens,
|
||||
) = encoding.tokenize_function_annotations(
|
||||
input.function_annotations,
|
||||
reference_sequence=reference_sequence,
|
||||
function_tokenizer=self.tokenizers.function,
|
||||
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
|
||||
add_special_tokens=True,
|
||||
(function_tokens, residue_annotation_tokens) = (
|
||||
encoding.tokenize_function_annotations(
|
||||
input.function_annotations,
|
||||
reference_sequence=reference_sequence,
|
||||
function_tokenizer=self.tokenizers.function,
|
||||
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
|
||||
add_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
return ESMProteinTensor(
|
||||
@@ -510,10 +497,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
coordinates=coordinates,
|
||||
).to(next(self.parameters()).device)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
input: ESMProteinTensor,
|
||||
) -> ESMProtein:
|
||||
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
||||
return decode_protein_tensor(
|
||||
input=input,
|
||||
tokenizers=self.tokenizers,
|
||||
@@ -613,10 +597,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
|
||||
logits_output: LogitsOutput = _batch_forward(self, batched_protein)
|
||||
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
|
||||
batched_protein,
|
||||
logits_output,
|
||||
sampling_config,
|
||||
self.tokenizers,
|
||||
batched_protein, logits_output, sampling_config, self.tokenizers
|
||||
)
|
||||
|
||||
# There is only 1 prompt to sample for.
|
||||
|
||||
@@ -167,8 +167,7 @@ class FunctionTokenDecoder(nn.Module):
|
||||
# Apply depth-position offset to use distinct vocabs. See __init__ for
|
||||
# explanation.
|
||||
vocab_offsets = self.config.function_token_vocab_size * torch.arange(
|
||||
self.config.function_token_depth,
|
||||
device=token_ids.device,
|
||||
self.config.function_token_depth, device=token_ids.device
|
||||
)
|
||||
inputs = token_ids + vocab_offsets[None, :]
|
||||
|
||||
@@ -251,8 +250,7 @@ class FunctionTokenDecoder(nn.Module):
|
||||
annotations.append(annotation)
|
||||
|
||||
annotations = merge_annotations(
|
||||
annotations,
|
||||
merge_gap_max=annotation_gap_merge_max,
|
||||
annotations, merge_gap_max=annotation_gap_merge_max
|
||||
)
|
||||
|
||||
# Drop very small annotations.
|
||||
|
||||
@@ -87,10 +87,7 @@ class PairwisePredictionHead(nn.Module):
|
||||
|
||||
prod = q[:, None, :, :] * k[:, :, None, :]
|
||||
diff = q[:, None, :, :] - k[:, :, None, :]
|
||||
x_2d = [
|
||||
prod,
|
||||
diff,
|
||||
]
|
||||
x_2d = [prod, diff]
|
||||
if pairwise is not None:
|
||||
x_2d.append(pairwise)
|
||||
x = torch.cat(x_2d, dim=-1)
|
||||
@@ -289,11 +286,7 @@ class StructureTokenEncoder(nn.Module):
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore
|
||||
ca = coords[..., 1, :]
|
||||
edges, edge_mask = knn_graph(
|
||||
ca,
|
||||
coord_mask,
|
||||
padding_mask,
|
||||
sequence_id,
|
||||
no_knn=knn,
|
||||
ca, coord_mask, padding_mask, sequence_id, no_knn=knn
|
||||
)
|
||||
|
||||
return edges, edge_mask
|
||||
@@ -333,12 +326,7 @@ class StructureTokenEncoder(nn.Module):
|
||||
|
||||
|
||||
class StructureTokenDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
n_heads,
|
||||
n_layers,
|
||||
):
|
||||
def __init__(self, d_model, n_heads, n_layers):
|
||||
super().__init__()
|
||||
self.decoder_channels = d_model
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import attr
|
||||
import torch
|
||||
@@ -19,14 +19,10 @@ from esm.utils.misc import (
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.structure.protein_complex import ProteinComplex
|
||||
from esm.utils.types import (
|
||||
FunctionAnnotation,
|
||||
PathOrBuffer,
|
||||
)
|
||||
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
||||
|
||||
|
||||
class ProteinType(ABC):
|
||||
...
|
||||
class ProteinType(ABC): ...
|
||||
|
||||
|
||||
## Basic Types
|
||||
@@ -184,6 +180,9 @@ class ESMProteinTensor(ProteinType):
|
||||
# Such sequences may not go through standard safety filter for approved users.
|
||||
# Reach out if interested in using this.
|
||||
potential_sequence_of_concern: bool = False
|
||||
# Control vectors are vectors added to each layer of the model to nudge hidden states to the desired direction.
|
||||
# len(control_vectors) == number of blocks in the model. Each vector in the list has the shape (batch size, sequence length, hidden dim)
|
||||
# so it can be added to the corresponding layer in the model
|
||||
|
||||
def _detect_attribute(self, func, msg):
|
||||
mapped = {
|
||||
@@ -260,20 +259,40 @@ class ESMProteinError(Exception, ProteinType):
|
||||
class GenerationConfig:
|
||||
track: str = ""
|
||||
invalid_ids: Sequence[int] = []
|
||||
schedule: str = "cosine"
|
||||
# Controls the number of tokens to unmask during each round of iterative generation.
|
||||
schedule: str = attr.field(
|
||||
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
|
||||
)
|
||||
# Controls which tokens to unmask during each round of iterative generation.
|
||||
# "random" will unmask a correct number of tokens randomly.
|
||||
# "entropy" will unmask the tokens with the lowest logit entropy first.
|
||||
strategy: str = attr.field(
|
||||
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
|
||||
)
|
||||
# Set this to a higher value for better generation results.
|
||||
# Note that this needs to be less than or equal to the sequence length.
|
||||
num_steps: int = 1
|
||||
temperature: float = 1.0
|
||||
temperature_annealing: bool = False
|
||||
top_p: float = 1.0
|
||||
condition_on_coordinates_only: bool = True
|
||||
|
||||
def use_entropy_based_unmasking_strategy(self):
    """Switch this config to entropy-driven unmasking.

    Sets ``strategy`` to ``"entropy"`` and ``schedule`` to ``"cosine"``, and
    turns ``temperature_annealing`` off. The config is modified in place and
    nothing is returned.
    """
    # The three fields are independent of each other, so the order of the
    # assignments does not matter.
    self.temperature_annealing = False
    self.strategy = "entropy"
    self.schedule = "cosine"
|
||||
|
||||
def use_generative_unmasking_strategy(self):
    """Switch this config to the unmasking setup that gives more varied generations.

    Sets ``strategy`` to ``"random"`` and ``schedule`` to ``"cosine"``, and
    turns ``temperature_annealing`` on. The config is modified in place and
    nothing is returned.
    """
    # The three fields are independent of each other, so the order of the
    # assignments does not matter.
    self.temperature_annealing = True
    self.strategy = "random"
    self.schedule = "cosine"
|
||||
|
||||
|
||||
@define
class InverseFoldingConfig:
    """Options accompanying an inverse-folding request.

    NOTE(review): in the Forge client visible in this commit only
    ``invalid_ids`` and ``temperature`` are forwarded to the ``inverse_fold``
    endpoint; whether ``schedule`` and ``num_steps`` are consumed elsewhere
    cannot be told from here.
    """

    # A factory is required here: attrs does not copy default values, so a bare
    # ``[]`` would be one list object shared (and mutated) across all instances.
    invalid_ids: Sequence[int] = attr.field(factory=list)
    schedule: str = "cosine"
    num_steps: int = 1
    temperature: float = 1.0
|
||||
|
||||
|
||||
@@ -370,9 +389,7 @@ class ESM3InferenceClient(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def batch_generate(
|
||||
self,
|
||||
inputs: Sequence[ProteinType],
|
||||
configs: Sequence[GenerationConfig],
|
||||
self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
|
||||
) -> Sequence[ProteinType]:
|
||||
# Same as generate(...), but generates a batch of proteins at once.
|
||||
raise NotImplementedError
|
||||
|
||||
109
esm/sdk/forge.py
109
esm/sdk/forge.py
@@ -5,12 +5,7 @@ from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_result,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
@@ -20,6 +15,7 @@ from esm.sdk.api import (
|
||||
ForwardAndSampleOutput,
|
||||
ForwardTrackData,
|
||||
GenerationConfig,
|
||||
InverseFoldingConfig,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
ProteinType,
|
||||
@@ -55,7 +51,15 @@ def log_retry_attempt(retry_state):
|
||||
)
|
||||
|
||||
|
||||
class FoldForgeInferenceClient:
|
||||
def _validate_protein_tensor_input(input):
    """Ensure ``input`` is an ``ESMProteinTensor``.

    Returns ``None`` when the check passes.

    Raises:
        ValueError: if ``input`` is any other type; the message points the
            caller at the ``encode()`` API.
    """
    if isinstance(input, ESMProteinTensor):
        return

    raise ValueError(
        "Input must be an ESMProteinTensor instance. "
        "Use encode() API to encode an ESMProtein into ESMProteinTensor."
    )
|
||||
|
||||
|
||||
class SequenceStructureForgeInferenceClient:
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "https://forge.evolutionaryscale.ai",
|
||||
@@ -73,31 +77,51 @@ class FoldForgeInferenceClient:
|
||||
|
||||
def fold(
|
||||
self,
|
||||
model_name: str,
|
||||
sequence: str,
|
||||
potential_sequence_of_concern: bool,
|
||||
) -> torch.Tensor | ESMProteinError:
|
||||
request = {
|
||||
"model": model_name,
|
||||
"sequence": sequence,
|
||||
}
|
||||
model_name: str | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
request = {"sequence": sequence}
|
||||
if model_name is not None:
|
||||
request["model"] = model_name
|
||||
try:
|
||||
data = self._post(
|
||||
"fold",
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
)
|
||||
data = self._post("fold", request, potential_sequence_of_concern)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return data["coordinates"]
|
||||
return ESMProtein(
|
||||
coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True)
|
||||
)
|
||||
|
||||
def inverse_fold(
    self,
    coordinates: torch.Tensor,
    config: InverseFoldingConfig,
    potential_sequence_of_concern: bool,
    model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
    """Request a sequence for the given coordinates from the ``inverse_fold`` endpoint.

    Args:
        coordinates: Coordinates tensor; converted with ``maybe_list`` (NaN
            becomes ``None``) before being sent.
        config: Supplies ``invalid_ids`` and ``temperature`` for the request.
        potential_sequence_of_concern: Passed through to ``_post``.
        model_name: Optional model identifier; the ``"model"`` key is only
            added to the request when this is not ``None``.

    Returns:
        An ``ESMProtein`` built from the ``"sequence"`` field of the response,
        or the ``ESMProteinError`` raised by ``_post`` (returned, not raised).
    """
    sampling_options = {
        "invalid_ids": config.invalid_ids,
        "temperature": config.temperature,
    }
    payload = {
        "coordinates": maybe_list(coordinates, convert_nan_to_none=True),
        "inverse_folding_config": sampling_options,
    }
    if model_name is not None:
        payload["model"] = model_name

    try:
        data = self._post("inverse_fold", payload, potential_sequence_of_concern)
    except ESMProteinError as err:
        # Errors are handed back to the caller as a value.
        return err
    else:
        return ESMProtein(sequence=data["sequence"])
|
||||
|
||||
def _post(self, endpoint, request, potential_sequence_of_concern):
|
||||
request["potential_sequence_of_concern"] = potential_sequence_of_concern
|
||||
|
||||
model_name_url = request["model"] if request["model"] != "esm3" else "api"
|
||||
response = requests.post(
|
||||
urljoin(self.url, f"/{model_name_url}/v1/{endpoint}"),
|
||||
urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
json=request,
|
||||
headers=self.headers,
|
||||
timeout=self.request_timeout,
|
||||
@@ -115,6 +139,11 @@ class FoldForgeInferenceClient:
|
||||
if "outputs" not in data and "data" in data:
|
||||
data = data["data"]
|
||||
|
||||
# Print warning message if there is any.
|
||||
if "warning_messages" in data and data["warning_messages"] is not None:
|
||||
for msg in data["warning_messages"]:
|
||||
print("\033[31m", msg, "\033[0m")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -174,18 +203,13 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
output = self.__generate_protein_tensor(input, config)
|
||||
else:
|
||||
return ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Unknown input type {type(input)}",
|
||||
error_code=500, error_msg=f"Unknown input type {type(input)}"
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(output, ESMProtein)
|
||||
and isinstance(input, ESMProtein)
|
||||
and config.track
|
||||
not in [
|
||||
"function",
|
||||
"residue_annotations",
|
||||
]
|
||||
and config.track not in ["function", "residue_annotations"]
|
||||
):
|
||||
# Function and residue annotation encoding/decoding is lossy
|
||||
# There is no guarantee that decoding encoded tokens will yield the same input
|
||||
@@ -218,9 +242,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
return [_capture_exception(r) for r in results]
|
||||
|
||||
def __generate_protein(
|
||||
self,
|
||||
input: ESMProtein,
|
||||
config: GenerationConfig,
|
||||
self, input: ESMProtein, config: GenerationConfig
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
req = {}
|
||||
req["sequence"] = input.sequence
|
||||
@@ -261,9 +283,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
)
|
||||
|
||||
def __generate_protein_tensor(
|
||||
self,
|
||||
input: ESMProteinTensor,
|
||||
config: GenerationConfig,
|
||||
self, input: ESMProteinTensor, config: GenerationConfig
|
||||
) -> ESMProteinTensor | ESMProteinError:
|
||||
req = {}
|
||||
req["sequence"] = maybe_list(input.sequence)
|
||||
@@ -316,6 +336,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
def forward_and_sample(
|
||||
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
|
||||
) -> ForwardAndSampleOutput | ESMProteinError:
|
||||
_validate_protein_tensor_input(input)
|
||||
validate_sampling_config(sampling_configuration, on_invalid="raise")
|
||||
|
||||
req = {}
|
||||
@@ -441,10 +462,9 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def decode(
|
||||
self,
|
||||
input: ESMProteinTensor,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
def decode(self, input: ESMProteinTensor) -> ESMProtein | ESMProteinError:
|
||||
_validate_protein_tensor_input(input)
|
||||
|
||||
tokens = {}
|
||||
tokens["sequence"] = maybe_list(input.sequence)
|
||||
tokens["structure"] = maybe_list(input.structure)
|
||||
@@ -454,10 +474,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
tokens["residue_annotation"] = maybe_list(input.residue_annotations)
|
||||
tokens["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
|
||||
|
||||
request = {
|
||||
"model": self.model,
|
||||
"inputs": tokens,
|
||||
}
|
||||
request = {"model": self.model, "inputs": tokens}
|
||||
|
||||
try:
|
||||
data = self._post("decode", request, input.potential_sequence_of_concern)
|
||||
@@ -482,6 +499,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
def logits(
|
||||
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
|
||||
) -> LogitsOutput | ESMProteinError:
|
||||
_validate_protein_tensor_input(input)
|
||||
|
||||
# Note: using raw model forwards is discouraged because of the byte size
|
||||
# of the logits.
|
||||
# Please use forward_and_sample instead.
|
||||
@@ -504,11 +523,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
"return_embeddings": config.return_embeddings,
|
||||
}
|
||||
|
||||
request = {
|
||||
"model": self.model,
|
||||
"inputs": req,
|
||||
"logits_config": logits_config,
|
||||
}
|
||||
request = {"model": self.model, "inputs": req, "logits_config": logits_config}
|
||||
|
||||
try:
|
||||
data = self._post("logits", request, input.potential_sequence_of_concern)
|
||||
|
||||
@@ -2,7 +2,61 @@ import json
|
||||
|
||||
import boto3
|
||||
|
||||
from esm.sdk.forge import ESM3ForgeInferenceClient
|
||||
from esm.sdk.forge import (
|
||||
ESM3ForgeInferenceClient,
|
||||
SequenceStructureForgeInferenceClient,
|
||||
)
|
||||
|
||||
|
||||
class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
    """Folding / inverse-folding client that sends requests to a SageMaker endpoint.

    Only the transport is overridden: ``_post`` wraps the request in an
    invocation envelope and sends it through the SageMaker runtime client.
    """

    def __init__(self, endpoint_name: str):
        """Create a client bound to one SageMaker endpoint.

        Args:
            endpoint_name: Name of the SageMaker endpoint.
        """
        # The parent constructor wants a URL and a token; placeholders are
        # passed to satisfy it.
        super().__init__(url="", token="dummy")

        self._endpoint_name = endpoint_name
        self._client = boto3.client(service_name="sagemaker-runtime")

    def _post(self, endpoint, request, potential_sequence_of_concern):
        """Invoke the SageMaker endpoint and return the payload for ``endpoint``.

        ``request`` is mutated in place: ``potential_sequence_of_concern`` is
        stored on it and ``"model"`` is set to ``None`` when absent.

        Raises:
            RuntimeError: if serialising or invoking the endpoint fails.
            AssertionError: if the response names a different endpoint.
        """
        request["potential_sequence_of_concern"] = potential_sequence_of_concern
        request.setdefault("model", None)

        envelope = {
            # Duplicate these fields at the top level to make Forge requests consistent.
            "model": request["model"],
            "request_id": "",  # Forge specific field.
            "user_id": "",  # Forge specific field.
            # Invocation data bits.
            "api_ver": "v1",  # Must be v1 right now.
            "endpoint": endpoint,
            # Wrapped request.
            endpoint: request,
        }

        try:
            response = self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Body=json.dumps(envelope),
            )
        except Exception as e:
            raise RuntimeError(f"Failure in {endpoint}: {e}") from e

        body = json.loads(response["Body"].read().decode())

        # Response must match request.
        returned_endpoint = body["endpoint"]
        assert (
            returned_endpoint == endpoint
        ), f"Response endpoint is {returned_endpoint} but request is {endpoint}"

        # The actual response lives under the endpoint key.
        return body[endpoint]
|
||||
|
||||
|
||||
class ESM3SageMakerClient(ESM3ForgeInferenceClient):
|
||||
|
||||
@@ -120,8 +120,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
|
||||
def _tfidf(self) -> tfidf.TFIDFModel:
|
||||
"""Creates TF-IDF model for encoding function keywords."""
|
||||
return tfidf.TFIDFModel(
|
||||
vocabulary_path=self.keyword_vocabulary_path,
|
||||
idf_path=self.keyword_idf_path,
|
||||
vocabulary_path=self.keyword_vocabulary_path, idf_path=self.keyword_idf_path
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@@ -205,9 +204,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
|
||||
return tokens
|
||||
|
||||
def _function_text_hash(
|
||||
self,
|
||||
labels: Collection[str],
|
||||
keyword_mask: np.ndarray | None = None,
|
||||
self, labels: Collection[str], keyword_mask: np.ndarray | None = None
|
||||
) -> np.ndarray | None:
|
||||
"""Applies a locality sensitive hash (LSH) to function text.
|
||||
|
||||
@@ -295,9 +292,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
|
||||
raise ValueError(f"Unknown token: {token}")
|
||||
|
||||
def batch_encode(
|
||||
self,
|
||||
token_batch: list[list[str]],
|
||||
add_special_tokens: bool = True,
|
||||
self, token_batch: list[list[str]], add_special_tokens: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""Encodes batch of function tokens.
|
||||
|
||||
@@ -312,8 +307,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
|
||||
for tokens in token_batch
|
||||
]
|
||||
return stack_variable_length_tensors(
|
||||
encoded,
|
||||
constant_value=self.vocab_to_index["<pad>"],
|
||||
encoded, constant_value=self.vocab_to_index["<pad>"]
|
||||
)
|
||||
|
||||
def decode(self, encoded: torch.Tensor):
|
||||
|
||||
@@ -13,11 +13,7 @@ Sample = dict[str, Any]
|
||||
|
||||
|
||||
class ResidueAnnotationsTokenizer(EsmTokenizerBase):
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str | None = None,
|
||||
max_annotations: int = 16,
|
||||
):
|
||||
def __init__(self, csv_path: str | None = None, max_annotations: int = 16):
|
||||
if csv_path is None:
|
||||
csv_path = str(C.data_root() / C.RESID_CSV)
|
||||
self.csv_path = csv_path
|
||||
|
||||
@@ -31,7 +31,7 @@ class SASADiscretizingTokenizer(EsmTokenizerBase):
|
||||
return self.special_tokens + range_tokens
|
||||
|
||||
@cached_property
|
||||
def midpoints(self) -> list[float]:
|
||||
def midpoints_tensor(self) -> torch.Tensor:
|
||||
"""Midpoints of the SASA token ranges."""
|
||||
boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
|
||||
midpoint_tokens = [
|
||||
@@ -39,7 +39,11 @@ class SASADiscretizingTokenizer(EsmTokenizerBase):
|
||||
for low, high in zip(boundaries[:-1], boundaries[1:])
|
||||
]
|
||||
midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
|
||||
return midpoint_tokens
|
||||
return torch.Tensor(midpoint_tokens)
|
||||
|
||||
def midpoints(self) -> list[float]:
|
||||
"""Midpoints of the SASA token ranges."""
|
||||
return self.midpoints_tensor.tolist()
|
||||
|
||||
@cached_property
|
||||
def vocab_to_index(self) -> dict[str, int]:
|
||||
@@ -86,7 +90,11 @@ class SASADiscretizingTokenizer(EsmTokenizerBase):
|
||||
|
||||
def decode_float(self, encoded: torch.Tensor) -> list[float]:
    """Decodes SASA token ids into float values.

    Each token id is looked up in ``self.midpoints_tensor``. Ids whose
    midpoint is NaN come back as ``None`` instead of NaN.
    NOTE(review): the NaN entries are presumably the special tokens — confirm.

    Args:
        encoded: Integer tensor of token ids; it is moved to CPU before indexing.

    Returns:
        A (possibly nested) list with the same shape as ``encoded``, holding
        Python floats and ``None``.
    """
    decoded = self.midpoints_tensor[encoded.cpu()]
    nan_mask = torch.isnan(decoded).numpy()
    # A float ndarray cannot hold None: assigning None into it silently writes
    # NaN back. Cast to an object array first so the None values survive.
    values = decoded.numpy().astype(object)
    values[nan_mask] = None
    return values.tolist()
|
||||
|
||||
def decode(self, encoded: torch.Tensor) -> str:
|
||||
"""Decodes SASA token ids."""
|
||||
|
||||
@@ -40,9 +40,7 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
|
||||
self.cb_token = chain_break_token
|
||||
additional_special_tokens = [chain_break_token]
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
special_tokens,
|
||||
)
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
|
||||
# This is where we configure the automatic addition of special tokens when we call
|
||||
# tokenizer(text, add_special_tokens=True). Note that you can also configure how two
|
||||
|
||||
@@ -3,56 +3,42 @@ from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class EsmTokenizerBase(Protocol):
|
||||
def encode(self, *args, **kwargs):
|
||||
...
|
||||
def encode(self, *args, **kwargs): ...
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
...
|
||||
def decode(self, *args, **kwargs): ...
|
||||
|
||||
@property
|
||||
def mask_token(self) -> str:
|
||||
...
|
||||
def mask_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def mask_token_id(self) -> int:
|
||||
...
|
||||
def mask_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str:
|
||||
...
|
||||
def bos_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
...
|
||||
def bos_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
...
|
||||
def eos_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
...
|
||||
def eos_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
...
|
||||
def pad_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
...
|
||||
def pad_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def chain_break_token(self) -> str:
|
||||
...
|
||||
def chain_break_token(self) -> str: ...
|
||||
|
||||
@property
|
||||
def chain_break_token_id(self) -> int:
|
||||
...
|
||||
def chain_break_token_id(self) -> int: ...
|
||||
|
||||
@property
|
||||
def all_token_ids(self):
|
||||
...
|
||||
def all_token_ids(self): ...
|
||||
|
||||
@property
|
||||
def special_token_ids(self):
|
||||
...
|
||||
def special_token_ids(self): ...
|
||||
|
||||
@@ -112,9 +112,7 @@ INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
|
||||
INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
|
||||
INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
|
||||
|
||||
LSH_TABLE_PATHS = {
|
||||
"8bit": "data/hyperplanes_8bit_58641.npz",
|
||||
}
|
||||
LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
|
||||
|
||||
KEYWORDS_VOCABULARY = (
|
||||
IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import warnings
|
||||
from typing import cast
|
||||
|
||||
import attr
|
||||
import torch
|
||||
@@ -31,7 +32,7 @@ from esm.utils.function.encode_decode import (
|
||||
decode_function_tokens,
|
||||
decode_residue_annotation_tokens,
|
||||
)
|
||||
from esm.utils.misc import list_nan_to_none
|
||||
from esm.utils.misc import maybe_list
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -130,15 +131,10 @@ def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
|
||||
|
||||
|
||||
def decode_sequence(
|
||||
sequence_tokens: torch.Tensor,
|
||||
sequence_tokenizer: EsmSequenceTokenizer,
|
||||
**kwargs,
|
||||
sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs
|
||||
) -> str:
|
||||
_bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
|
||||
sequence = sequence_tokenizer.decode(
|
||||
sequence_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
sequence = sequence_tokenizer.decode(sequence_tokens, **kwargs)
|
||||
sequence = sequence.replace(" ", "")
|
||||
sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT)
|
||||
sequence = sequence.replace(sequence_tokenizer.cls_token, "")
|
||||
@@ -185,20 +181,16 @@ def decode_structure(
|
||||
|
||||
|
||||
def decode_secondary_structure(
|
||||
secondary_structure_tokens: torch.Tensor,
|
||||
ss_tokenizer: SecondaryStructureTokenizer,
|
||||
secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer
|
||||
) -> str:
|
||||
_bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
|
||||
secondary_structure_tokens = secondary_structure_tokens[1:-1]
|
||||
secondary_structure = ss_tokenizer.decode(
|
||||
secondary_structure_tokens,
|
||||
)
|
||||
secondary_structure = ss_tokenizer.decode(secondary_structure_tokens)
|
||||
return secondary_structure
|
||||
|
||||
|
||||
def decode_sasa(
|
||||
sasa_tokens: torch.Tensor,
|
||||
sasa_tokenizer: SASADiscretizingTokenizer,
|
||||
sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer
|
||||
) -> list[float]:
|
||||
if sasa_tokens[0] != 0:
|
||||
raise ValueError("SASA does not start with 0 corresponding to BOS token")
|
||||
@@ -213,12 +205,13 @@ def decode_sasa(
|
||||
torch.long,
|
||||
]:
|
||||
# Decode if int
|
||||
# handles turning NaN's into None's
|
||||
sasa = sasa_tokenizer.decode_float(sasa_tokens)
|
||||
else:
|
||||
# If already float, just convert to list
|
||||
sasa = sasa_tokens.tolist()
|
||||
sasa = cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True))
|
||||
|
||||
return list_nan_to_none(sasa)
|
||||
return sasa
|
||||
|
||||
|
||||
def decode_function_annotations(
|
||||
|
||||
@@ -97,9 +97,7 @@ def tokenize_structure(
|
||||
# Add space for BOS and EOS tokens
|
||||
if add_special_tokens:
|
||||
coordinates = F.pad(
|
||||
coordinates,
|
||||
(0, 0, 0, 0, left_pad, right_pad),
|
||||
value=torch.inf,
|
||||
coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf
|
||||
)
|
||||
plddt = F.pad(plddt, (left_pad, right_pad), value=0)
|
||||
structure_tokens = F.pad(
|
||||
@@ -171,8 +169,7 @@ def tokenize_function_annotations(
|
||||
|
||||
# Tokenized Defaults
|
||||
def get_default_sequence_tokens(
|
||||
sequence_length: int,
|
||||
sequence_tokenizer: EsmSequenceTokenizer,
|
||||
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
||||
) -> torch.Tensor:
|
||||
assert sequence_tokenizer.mask_token_id is not None
|
||||
assert sequence_tokenizer.bos_token_id is not None
|
||||
@@ -191,10 +188,7 @@ def get_default_structure_tokens(
|
||||
sequence_length: int, structure_tokenizer: StructureTokenizer
|
||||
) -> torch.Tensor:
|
||||
structure_tokens = (
|
||||
torch.ones(
|
||||
(sequence_length + 2,),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.ones((sequence_length + 2,), dtype=torch.int64)
|
||||
* structure_tokenizer.mask_token_id
|
||||
)
|
||||
# Always include BOS and EOS tokens
|
||||
@@ -241,10 +235,7 @@ def get_default_residue_annotation_tokens(
|
||||
sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
|
||||
) -> torch.Tensor:
|
||||
residue_annotation_tokens = (
|
||||
torch.ones(
|
||||
(sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.ones((sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), dtype=torch.int64)
|
||||
* residue_annotation_tokenizer.pad_token_id
|
||||
)
|
||||
# Always include BOS and EOS tokens
|
||||
|
||||
@@ -59,8 +59,7 @@ def encode_function_annotations(
|
||||
|
||||
# Convert function token FunctionAnnotations -> Tensor
|
||||
function_tokens = function_tokens_tokenizer.tokenize(
|
||||
annotations=ft_annotations,
|
||||
seqlen=len(sequence),
|
||||
annotations=ft_annotations, seqlen=len(sequence)
|
||||
)
|
||||
function_token_ids = function_tokens_tokenizer.encode(
|
||||
function_tokens, add_special_tokens=add_special_tokens
|
||||
@@ -175,10 +174,7 @@ def decode_residue_annotation_tokens(
|
||||
annotation = FunctionAnnotation(label=label, start=loc, end=loc)
|
||||
annotations.append(annotation)
|
||||
|
||||
annotations = merge_annotations(
|
||||
annotations,
|
||||
merge_gap_max=annotation_gap_merge_max,
|
||||
)
|
||||
annotations = merge_annotations(annotations, merge_gap_max=annotation_gap_merge_max)
|
||||
|
||||
# Drop very small annotations.
|
||||
if annotation_min_length is not None:
|
||||
|
||||
@@ -127,11 +127,7 @@ class InterPro:
|
||||
col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
|
||||
)
|
||||
df.rename(
|
||||
columns={
|
||||
"ENTRY_AC": "id",
|
||||
"ENTRY_TYPE": "type",
|
||||
"ENTRY_NAME": "name",
|
||||
},
|
||||
columns={"ENTRY_AC": "id", "ENTRY_TYPE": "type", "ENTRY_NAME": "name"},
|
||||
inplace=True,
|
||||
)
|
||||
df["type"] = df.type.str.upper().apply(
|
||||
|
||||
@@ -50,8 +50,7 @@ class TFIDFModel:
|
||||
values /= np.linalg.norm(values)
|
||||
|
||||
return sparse.csr_matrix(
|
||||
(values, (np.zeros_like(indices), indices)),
|
||||
shape=(1, len(self.vocabulary)),
|
||||
(values, (np.zeros_like(indices), indices)), shape=(1, len(self.vocabulary))
|
||||
)
|
||||
|
||||
def decode(self, vec: sparse.csr_matrix) -> list[str]:
|
||||
|
||||
@@ -128,9 +128,7 @@ def iterative_sampling_raw(
|
||||
|
||||
|
||||
def _make_masked_inputs(
|
||||
track: str,
|
||||
sequence_length: int,
|
||||
tokenizers: TokenizerCollectionProtocol,
|
||||
track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol
|
||||
):
|
||||
get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
|
||||
|
||||
@@ -190,8 +188,7 @@ def _stack_protein_tensors(
|
||||
o,
|
||||
fn,
|
||||
stack_variable_length_tensors(
|
||||
sequences=tensors,
|
||||
constant_value=mask_token_id,
|
||||
sequences=tensors, constant_value=mask_token_id
|
||||
),
|
||||
)
|
||||
|
||||
@@ -240,6 +237,11 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
|
||||
shape = tokens.shape
|
||||
B, L = shape[0], shape[1]
|
||||
|
||||
# TODO: figure out why we want this function to work with
|
||||
# _BatchedESMProteinTensor in the first place. Logics below
|
||||
# don't really work for batched tensors.
|
||||
assert B == 1
|
||||
|
||||
sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device)
|
||||
sampling_mask[:, 0] = False # BOS
|
||||
# EOS and all padding tokens.
|
||||
@@ -248,9 +250,7 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
|
||||
).to(device)
|
||||
|
||||
is_mask = _get_masked_positions(
|
||||
track_to_sample,
|
||||
tokens,
|
||||
getattr(tokenizers, track_to_sample).mask_token_id,
|
||||
track_to_sample, tokens, getattr(tokenizers, track_to_sample).mask_token_id
|
||||
)
|
||||
if not is_mask.any().item():
|
||||
raise ValueError(f"Cannot sample {config.track} when input has no masks.")
|
||||
@@ -273,27 +273,36 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
|
||||
).int()
|
||||
num_to_sample = still_masked - num_tokens_masked_after_this_step
|
||||
|
||||
track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to(
|
||||
device
|
||||
) # (B, L) or (B, L, D)
|
||||
if config.strategy == "entropy":
|
||||
track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to(
|
||||
device
|
||||
) # (B, L) or (B, L, D)
|
||||
|
||||
if track_to_sample == "function":
|
||||
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
|
||||
if track_to_sample == "function":
|
||||
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
|
||||
|
||||
track_entropy = track_entropy.masked_fill(
|
||||
~sampling_mask, torch.finfo(track_entropy.dtype).max
|
||||
)
|
||||
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
|
||||
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
|
||||
1, indices, True
|
||||
)
|
||||
where_to_sample = sampling_mask & is_top_k
|
||||
track_entropy = track_entropy.masked_fill(
|
||||
~sampling_mask, torch.finfo(track_entropy.dtype).max
|
||||
)
|
||||
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
|
||||
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
|
||||
1, indices, True
|
||||
)
|
||||
where_to_sample = sampling_mask & is_top_k
|
||||
elif config.strategy == "random":
|
||||
# Skip B since we know there is only 1 prompt here.
|
||||
_, masked_indices = sampling_mask.nonzero(as_tuple=True)
|
||||
# Random shuffle the masked indices then select the first num_to_sample.
|
||||
rnd_indices = masked_indices[torch.randperm(len(masked_indices))][
|
||||
:num_to_sample
|
||||
]
|
||||
rnd_mask = torch.zeros_like(sampling_mask)
|
||||
rnd_mask[:, rnd_indices] = True
|
||||
where_to_sample = sampling_mask & rnd_mask
|
||||
|
||||
if track_to_sample == "function":
|
||||
where_to_sample = where_to_sample.unsqueeze(-1).expand(
|
||||
B,
|
||||
L,
|
||||
tokenizers.function.depth,
|
||||
B, L, tokenizers.function.depth
|
||||
) # (B, L) -> (B, L, D)
|
||||
|
||||
return where_to_sample
|
||||
@@ -316,6 +325,11 @@ def _get_non_special_tokens(
|
||||
return int(torch.sum(mask).item())
|
||||
|
||||
|
||||
def _get_annealed_temperature(step: int, num_steps: int, initial_temperature: float):
|
||||
step_ratio = step / max(1, (num_steps - 1))
|
||||
return max(initial_temperature - step_ratio, 0.001) ** 2
|
||||
|
||||
|
||||
def iterative_sampling_tokens(
|
||||
client: ESM3InferenceClient,
|
||||
input_tokens: list[ESMProteinTensor],
|
||||
@@ -345,9 +359,7 @@ def iterative_sampling_tokens(
|
||||
num_sampling_steps = _get_non_special_tokens(protein, tokenizers)
|
||||
else:
|
||||
masked = _get_masked_positions(
|
||||
track,
|
||||
getattr(protein, track),
|
||||
getattr(tokenizers, track).mask_token_id,
|
||||
track, getattr(protein, track), getattr(tokenizers, track).mask_token_id
|
||||
)
|
||||
num_sampling_steps = torch.sum(masked).item()
|
||||
|
||||
@@ -365,10 +377,7 @@ def iterative_sampling_tokens(
|
||||
|
||||
# Now stack the list to make a single batched ESMProteinTensor.
|
||||
batched_tokens = _stack_protein_tensors(
|
||||
sampled_tokens,
|
||||
sequence_lengths,
|
||||
tokenizers,
|
||||
devices.pop(),
|
||||
sampled_tokens, sequence_lengths, tokenizers, devices.pop()
|
||||
)
|
||||
|
||||
# Remember sampled prompts that has somehow errored out.
|
||||
@@ -418,9 +427,18 @@ def iterative_sampling_tokens(
|
||||
len(per_prompt_cur_sampled),
|
||||
)
|
||||
|
||||
# Handle temperature annealing, since _sample_per_prompt() doesn't have
|
||||
# the concept of decoding steps.
|
||||
if config.temperature_annealing:
|
||||
temperature = _get_annealed_temperature(
|
||||
t, config.num_steps, config.temperature
|
||||
)
|
||||
else:
|
||||
temperature = config.temperature
|
||||
|
||||
track_sample_config = SamplingTrackConfig()
|
||||
track_sample_config.invalid_ids = config.invalid_ids
|
||||
track_sample_config.temperature = config.temperature
|
||||
track_sample_config.temperature = temperature
|
||||
track_sample_config.top_p = config.top_p
|
||||
sampling_config = SamplingConfig(**{config.track: track_sample_config}) # type: ignore
|
||||
|
||||
@@ -486,7 +504,7 @@ def iterative_sampling_tokens(
|
||||
setattr(outputs, "coordinates", getattr(inputs, "coordinates"))
|
||||
# Maybe restore all the other fields.
|
||||
for f in attr.fields(SamplingConfig):
|
||||
if "embedding" in f.name:
|
||||
if "embedding" in f.name or f.name == "return_hidden_states":
|
||||
continue
|
||||
if f.name != config.track:
|
||||
setattr(outputs, f.name, getattr(inputs, f.name))
|
||||
@@ -494,10 +512,7 @@ def iterative_sampling_tokens(
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _batch_forward(
|
||||
client: ESM3InferenceClient,
|
||||
protein: _BatchedESMProteinTensor,
|
||||
):
|
||||
def _batch_forward(client: ESM3InferenceClient, protein: _BatchedESMProteinTensor):
|
||||
# Forward pass
|
||||
return client.logits(
|
||||
protein,
|
||||
|
||||
@@ -97,9 +97,7 @@ def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_encode_chainbreak_token(esm3_remote_inference_client):
|
||||
protein = esm3_remote_inference_client.encode(
|
||||
ESMProtein(sequence="MSTNP|KPQKK"),
|
||||
)
|
||||
protein = esm3_remote_inference_client.encode(ESMProtein(sequence="MSTNP|KPQKK"))
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.sequence is not None
|
||||
assert (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import ContextManager, Sequence, TypeVar
|
||||
@@ -226,8 +225,7 @@ def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[
|
||||
|
||||
|
||||
def merge_annotations(
|
||||
annotations: list[FunctionAnnotation],
|
||||
merge_gap_max: int | None = None,
|
||||
annotations: list[FunctionAnnotation], merge_gap_max: int | None = None
|
||||
) -> list[FunctionAnnotation]:
|
||||
"""Merges annotations into non-overlapping segments.
|
||||
|
||||
@@ -256,42 +254,24 @@ def merge_annotations(
|
||||
return merged
|
||||
|
||||
|
||||
def list_nan_to_none(l: list) -> list:
|
||||
if l is None:
|
||||
return None # type: ignore
|
||||
elif isinstance(l, float):
|
||||
return None if math.isnan(l) else l # type: ignore
|
||||
elif isinstance(l, list):
|
||||
return [list_nan_to_none(x) for x in l]
|
||||
else:
|
||||
# Don't go into other structures.
|
||||
return l
|
||||
|
||||
|
||||
def list_none_to_nan(l: list) -> list:
|
||||
if l is None:
|
||||
return math.nan # type: ignore
|
||||
elif isinstance(l, list):
|
||||
return [list_none_to_nan(x) for x in l]
|
||||
else:
|
||||
return l
|
||||
|
||||
|
||||
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
|
||||
if x is None:
|
||||
return None
|
||||
if convert_none_to_nan:
|
||||
x = list_none_to_nan(x)
|
||||
x = np.array(x, copy=False, dtype=np.float32)
|
||||
x = np.where(x is None, np.nan, x)
|
||||
return torch.tensor(x)
|
||||
|
||||
|
||||
def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
|
||||
if x is None:
|
||||
return None
|
||||
x = x.tolist()
|
||||
if convert_nan_to_none:
|
||||
x = list_nan_to_none(x)
|
||||
return x
|
||||
if not convert_nan_to_none:
|
||||
return x.tolist()
|
||||
nan_mask = torch.isnan(x)
|
||||
np_arr = x.cpu().numpy().astype(object)
|
||||
np_arr[nan_mask.cpu().numpy()] = None
|
||||
return np_arr.tolist()
|
||||
|
||||
|
||||
def huggingfacehub_login():
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for misc.py"""
|
||||
|
||||
|
||||
from esm.utils.misc import merge_annotations
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -22,16 +22,30 @@ from esm.utils.constants.esm3 import (
|
||||
SASA_DISCRETIZATION_BOUNDARIES,
|
||||
)
|
||||
|
||||
# Number of dimensions for each protein tensor field without the batch dimension.
|
||||
_DIMS: dict[str, int] = {
|
||||
"sequence": 1,
|
||||
"structure": 1,
|
||||
"secondary_structure": 1,
|
||||
"sasa": 1,
|
||||
"function": 2,
|
||||
"residue_annotations": 2,
|
||||
"coordinates": 3,
|
||||
}
|
||||
|
||||
def _non_batched_dims(k: str, v: torch.Tensor):
|
||||
match k:
|
||||
case "sequence":
|
||||
return 1
|
||||
case "structure":
|
||||
if v.is_floating_point():
|
||||
# This is the one hot soft structure token.
|
||||
return 2
|
||||
else:
|
||||
# This is the normal int structure token.
|
||||
return 1
|
||||
case "secondary_structure":
|
||||
return 1
|
||||
case "sasa":
|
||||
return 1
|
||||
case "function":
|
||||
return 2
|
||||
case "residue_annotations":
|
||||
return 2
|
||||
case "coordinates":
|
||||
return 3
|
||||
case _:
|
||||
raise ValueError(f"Unknown dim for track {k}")
|
||||
|
||||
|
||||
class _BatchedESMProteinTensor(ESMProteinTensor):
|
||||
@@ -52,7 +66,7 @@ class _BatchedESMProteinTensor(ESMProteinTensor):
|
||||
|
||||
def __len__(self) -> int:
|
||||
def get_len(k, v) -> int:
|
||||
assert len(v.shape) == _DIMS[k] + 1
|
||||
assert len(v.shape) == _non_batched_dims(k, v) + 1
|
||||
return v.size(1)
|
||||
|
||||
l = self._detect_attribute(get_len, "length")
|
||||
@@ -61,18 +75,14 @@ class _BatchedESMProteinTensor(ESMProteinTensor):
|
||||
@property
|
||||
def batch_size(self) -> int:
|
||||
def get_batch_size(k, v) -> int:
|
||||
assert len(v.shape) == _DIMS[k] + 1
|
||||
assert len(v.shape) == _non_batched_dims(k, v) + 1
|
||||
return v.size(0)
|
||||
|
||||
d = self._detect_attribute(get_batch_size, "batch size")
|
||||
assert d is not None
|
||||
return d
|
||||
|
||||
def slice(
|
||||
self,
|
||||
i: int,
|
||||
sequence_len: int | None = None,
|
||||
) -> ESMProteinTensor:
|
||||
def slice(self, i: int, sequence_len: int | None = None) -> ESMProteinTensor:
|
||||
def _maybe_slice(x: torch.Tensor | None):
|
||||
if x is None:
|
||||
return None
|
||||
@@ -130,8 +140,7 @@ def get_default_sampling_config(
|
||||
|
||||
|
||||
def validate_sampling_config(
|
||||
sampling_config: SamplingConfig,
|
||||
on_invalid: Literal["raise", "warn"] = "warn",
|
||||
sampling_config: SamplingConfig, on_invalid: Literal["raise", "warn"] = "warn"
|
||||
):
|
||||
# Check that all tracks have topk_logprobs less or equal to MAX_TOP_K
|
||||
for track in attr.fields(SamplingConfig):
|
||||
@@ -288,10 +297,7 @@ def sample_sasa_logits(
|
||||
return sasa_value
|
||||
|
||||
|
||||
def top_p_logits(
|
||||
logits: torch.Tensor,
|
||||
top_p: float | torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def top_p_logits(logits: torch.Tensor, top_p: float | torch.Tensor) -> torch.Tensor:
|
||||
top_p = _tensorize_like(top_p, logits)
|
||||
|
||||
batch_dims = logits.size()[:-1]
|
||||
@@ -320,9 +326,7 @@ def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
|
||||
|
||||
|
||||
def get_sampling_mask(
|
||||
tokens: torch.Tensor,
|
||||
sampling_track_config: SamplingTrackConfig,
|
||||
mask_idx: int,
|
||||
tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int
|
||||
):
|
||||
# Do not sample at BOS and EOS tokens
|
||||
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, )
|
||||
|
||||
@@ -31,9 +31,7 @@ def test_sample_logits():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
sampled = sample_logits(
|
||||
logits=torch.randn((8, 4096)),
|
||||
temperature=0.0,
|
||||
valid_ids=[],
|
||||
logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=[]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,15 +12,12 @@ from esm.utils.misc import fp32_autocast_context
|
||||
@T.runtime_checkable
|
||||
class Rotation(T.Protocol):
|
||||
@classmethod
|
||||
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self:
|
||||
...
|
||||
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
|
||||
|
||||
@classmethod
|
||||
def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self:
|
||||
...
|
||||
def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
|
||||
|
||||
def __getitem__(self, idx: T.Any) -> Self:
|
||||
...
|
||||
def __getitem__(self, idx: T.Any) -> Self: ...
|
||||
|
||||
@property
|
||||
def tensor(self) -> torch.Tensor:
|
||||
@@ -35,8 +32,7 @@ class Rotation(T.Protocol):
|
||||
# This means that 1x4 quaternions are treated as size (1,) for example
|
||||
...
|
||||
|
||||
def as_matrix(self) -> RotationMatrix:
|
||||
...
|
||||
def as_matrix(self) -> RotationMatrix: ...
|
||||
|
||||
def compose(self, other: Self) -> Self:
|
||||
# To be safe, we force users to explicitly convert between rotation types.
|
||||
@@ -50,8 +46,7 @@ class Rotation(T.Protocol):
|
||||
# rotates points by this rotation object
|
||||
...
|
||||
|
||||
def invert(self) -> Self:
|
||||
...
|
||||
def invert(self) -> Self: ...
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@@ -194,10 +189,7 @@ class Affine3D:
|
||||
|
||||
def __getitem__(self, idx: T.Any) -> "Affine3D":
|
||||
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
|
||||
return Affine3D(
|
||||
trans=self.trans[indices + (slice(None),)],
|
||||
rot=self.rot[idx],
|
||||
)
|
||||
return Affine3D(trans=self.trans[indices + (slice(None),)], rot=self.rot[idx])
|
||||
|
||||
@property
|
||||
def shape(self) -> torch.Size:
|
||||
|
||||
@@ -17,8 +17,7 @@ class Alignable(Protocol):
|
||||
# Trick to detect whether an object is a dataclass
|
||||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
||||
|
||||
def __len__(self) -> int:
|
||||
...
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Alignable)
|
||||
|
||||
@@ -139,18 +139,13 @@ def compute_gdt_ts(
|
||||
"""
|
||||
if atom_exists_mask is None:
|
||||
atom_exists_mask = torch.isfinite(target).all(dim=-1)
|
||||
(
|
||||
centered_mobile,
|
||||
_,
|
||||
centered_target,
|
||||
_,
|
||||
rotation_matrix,
|
||||
_,
|
||||
) = compute_alignment_tensors(
|
||||
mobile=mobile,
|
||||
target=target,
|
||||
atom_exists_mask=atom_exists_mask,
|
||||
sequence_id=sequence_id,
|
||||
(centered_mobile, _, centered_target, _, rotation_matrix, _) = (
|
||||
compute_alignment_tensors(
|
||||
mobile=mobile,
|
||||
target=target,
|
||||
atom_exists_mask=atom_exists_mask,
|
||||
sequence_id=sequence_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply transformation to centered structure
|
||||
|
||||
@@ -43,10 +43,7 @@ def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
|
||||
Affine3D: tensor of Affine3D frame
|
||||
"""
|
||||
bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)
|
||||
coord_mask = torch.all(
|
||||
torch.all(torch.isfinite(bb_coords), dim=-1),
|
||||
dim=-1,
|
||||
)
|
||||
coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1)
|
||||
|
||||
average_position_per_n_ca_c = bb_coords.masked_fill(
|
||||
~coord_mask[..., None, None], 0
|
||||
|
||||
@@ -49,11 +49,7 @@ def compute_predicted_aligned_error(
|
||||
|
||||
|
||||
@torch.no_grad
|
||||
def compute_tm(
|
||||
logits: torch.Tensor,
|
||||
aa_mask: torch.Tensor,
|
||||
max_bin: float = 31.0,
|
||||
):
|
||||
def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0):
|
||||
square_mask = _compute_pae_masks(aa_mask)
|
||||
seqlens = aa_mask.sum(-1, keepdim=True)
|
||||
bins = _pae_bins(max_bin, logits.shape[-1], logits.device)
|
||||
|
||||
@@ -229,8 +229,7 @@ class ProteinChain:
|
||||
return buf.getvalue()
|
||||
|
||||
def to_structure_encoder_inputs(
|
||||
self,
|
||||
should_normalize_coordinates: bool = True,
|
||||
self, should_normalize_coordinates: bool = True
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
coords = torch.tensor(self.atom37_positions, dtype=torch.float32)
|
||||
plddt = torch.tensor(self.confidence, dtype=torch.float32)
|
||||
@@ -494,9 +493,7 @@ class ProteinChain:
|
||||
|
||||
@classmethod
|
||||
def from_backbone_atom_coordinates(
|
||||
cls,
|
||||
backbone_atom_coordinates: np.ndarray | torch.Tensor,
|
||||
**kwargs,
|
||||
cls, backbone_atom_coordinates: np.ndarray | torch.Tensor, **kwargs
|
||||
):
|
||||
"""Create a ProteinChain from a set of backbone atom coordinates.
|
||||
|
||||
@@ -529,10 +526,7 @@ class ProteinChain:
|
||||
)
|
||||
atom37_positions[:, :3, :] = backbone_atom_coordinates
|
||||
|
||||
return cls.from_atom37(
|
||||
atom37_positions=atom37_positions,
|
||||
**kwargs,
|
||||
)
|
||||
return cls.from_atom37(atom37_positions=atom37_positions, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pdb(
|
||||
@@ -586,22 +580,13 @@ class ProteinChain:
|
||||
num_res = len(sequence)
|
||||
|
||||
atom_positions = np.full(
|
||||
[num_res, RC.atom_type_num, 3],
|
||||
np.nan,
|
||||
dtype=np.float32,
|
||||
)
|
||||
atom_mask = np.full(
|
||||
[num_res, RC.atom_type_num],
|
||||
False,
|
||||
dtype=bool,
|
||||
[num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32
|
||||
)
|
||||
atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool)
|
||||
residue_index = np.full([num_res], -1, dtype=np.int64)
|
||||
insertion_code = np.full([num_res], "", dtype="<U4")
|
||||
|
||||
confidence = np.ones(
|
||||
[num_res],
|
||||
dtype=np.float32,
|
||||
)
|
||||
confidence = np.ones([num_res], dtype=np.float32)
|
||||
|
||||
for i, res in enumerate(bs.residue_iter(atom_array)):
|
||||
chain = atom_array[atom_array.chain_id == chain_id]
|
||||
@@ -639,20 +624,14 @@ class ProteinChain:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_rcsb(
|
||||
cls,
|
||||
pdb_id: str,
|
||||
chain_id: str = "detect",
|
||||
):
|
||||
def from_rcsb(cls, pdb_id: str, chain_id: str = "detect"):
|
||||
"""Fetch a protein chain from the RCSB PDB database."""
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
|
||||
return cls.from_pdb(f, chain_id=chain_id, id=pdb_id)
|
||||
|
||||
@classmethod
|
||||
def from_atomarray(
|
||||
cls,
|
||||
atom_array: bs.AtomArray,
|
||||
id: str | None = None,
|
||||
cls, atom_array: bs.AtomArray, id: str | None = None
|
||||
) -> "ProteinChain":
|
||||
"""A simple converter from bs.AtomArray -> ProteinChain.
|
||||
Uses PDB file format as intermediate."""
|
||||
|
||||
@@ -254,10 +254,7 @@ class ProteinComplex:
|
||||
return ProteinComplex.from_chains(chains)
|
||||
|
||||
@classmethod
|
||||
def from_rcsb(
|
||||
cls,
|
||||
pdb_id: str,
|
||||
):
|
||||
def from_rcsb(cls, pdb_id: str):
|
||||
"""Fetch a protein complex from the RCSB PDB database."""
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
|
||||
return cls.from_pdb(f, id=pdb_id)
|
||||
@@ -345,10 +342,7 @@ class ProteinComplex:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_chains(
|
||||
cls,
|
||||
chains: Sequence[ProteinChain],
|
||||
):
|
||||
def from_chains(cls, chains: Sequence[ProteinChain]):
|
||||
if not chains:
|
||||
raise ValueError(
|
||||
"Cannot create a ProteinComplex from an empty list of chains"
|
||||
|
||||
@@ -254,10 +254,7 @@ def compute_affine_and_rmsd(
|
||||
# Apply transformation to centered structure to compute rmsd
|
||||
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
|
||||
avg_rmsd = compute_rmsd_no_alignment(
|
||||
rotated_mobile,
|
||||
centered_target,
|
||||
num_valid_atoms,
|
||||
reduction="batch",
|
||||
rotated_mobile, centered_target, num_valid_atoms, reduction="batch"
|
||||
)
|
||||
|
||||
return affine, avg_rmsd
|
||||
|
||||
@@ -116,19 +116,10 @@ def create_function_annotator(
|
||||
)
|
||||
|
||||
delete_button = widgets.Button(
|
||||
description="Delete",
|
||||
tooltip="Delete this annotation",
|
||||
icon="trash",
|
||||
)
|
||||
entry = widgets.HBox(
|
||||
[
|
||||
delete_button,
|
||||
widgets.Label(value=function_str),
|
||||
]
|
||||
)
|
||||
delete_button.on_click(
|
||||
on_delete_click,
|
||||
description="Delete", tooltip="Delete this annotation", icon="trash"
|
||||
)
|
||||
entry = widgets.HBox([delete_button, widgets.Label(value=function_str)])
|
||||
delete_button.on_click(on_delete_click)
|
||||
entries.children += (entry,)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -150,8 +150,7 @@ def create_sequence_results_page(
|
||||
sequence_items = []
|
||||
for item in items:
|
||||
copy_to_prompt_button = widgets.Button(
|
||||
description="Copy to Prompt",
|
||||
disabled=copy_to_prompt_callback is None,
|
||||
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
|
||||
)
|
||||
if copy_to_prompt_callback:
|
||||
copy_to_prompt_button.on_click(
|
||||
@@ -190,8 +189,7 @@ def create_sasa_results_page(
|
||||
sasa_items = []
|
||||
for item in items:
|
||||
copy_to_prompt_button = widgets.Button(
|
||||
description="Copy to Prompt",
|
||||
disabled=copy_to_prompt_callback is None,
|
||||
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
|
||||
)
|
||||
if copy_to_prompt_callback:
|
||||
copy_to_prompt_button.on_click(lambda b: copy_to_prompt_callback(item.sasa))
|
||||
@@ -201,11 +199,7 @@ def create_sasa_results_page(
|
||||
print("Solvent Accessible Surface Area (SASA) is not available.")
|
||||
else:
|
||||
sasa = [s or 0 for s in item.sasa]
|
||||
draw_data_array(
|
||||
output,
|
||||
data_array=sasa,
|
||||
cmap="Reds",
|
||||
)
|
||||
draw_data_array(output, data_array=sasa, cmap="Reds")
|
||||
|
||||
if copy_to_prompt_callback:
|
||||
sasa_items.append(
|
||||
@@ -227,8 +221,7 @@ def create_secondary_structure_results_page(
|
||||
ss_items = []
|
||||
for item in items:
|
||||
copy_to_prompt_button = widgets.Button(
|
||||
description="Copy to Prompt",
|
||||
disabled=copy_to_prompt_callback is None,
|
||||
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
|
||||
)
|
||||
if copy_to_prompt_callback:
|
||||
copy_to_prompt_button.on_click(
|
||||
@@ -292,8 +285,7 @@ def create_structure_results_page(
|
||||
else:
|
||||
ptm_label = widgets.Label(value=f"pTM: {item.ptm.item():.2f}")
|
||||
copy_to_prompt_button = widgets.Button(
|
||||
description="Copy to Prompt",
|
||||
disabled=copy_to_prompt_callback is None,
|
||||
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
|
||||
)
|
||||
if copy_to_prompt_callback:
|
||||
copy_to_prompt_button.on_click(
|
||||
@@ -351,8 +343,7 @@ def create_structure_results_page(
|
||||
header = widgets.HBox([download_pdb_button, ptm_label])
|
||||
|
||||
grid[row, col] = widgets.VBox(
|
||||
[header, output],
|
||||
layout={"border": "1px solid gray"},
|
||||
[header, output], layout={"border": "1px solid gray"}
|
||||
)
|
||||
return grid
|
||||
|
||||
@@ -364,8 +355,7 @@ def create_function_annotations_results_page(
|
||||
function_items = []
|
||||
for item in items:
|
||||
copy_to_prompt_button = widgets.Button(
|
||||
description="Copy to Prompt",
|
||||
disabled=copy_to_prompt_callback is None,
|
||||
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
|
||||
)
|
||||
if copy_to_prompt_callback:
|
||||
copy_to_prompt_button.on_click(
|
||||
@@ -387,8 +377,7 @@ def create_function_annotations_results_page(
|
||||
)
|
||||
else:
|
||||
image = draw_function_annotations(
|
||||
interpro_annotations,
|
||||
sequence_length=len(item),
|
||||
interpro_annotations, sequence_length=len(item)
|
||||
)
|
||||
if copy_to_prompt_callback:
|
||||
content = widgets.VBox(
|
||||
|
||||
@@ -155,11 +155,7 @@ def get_secondary_structure(protein_chain: ProteinChain) -> Sequence[int]:
|
||||
|
||||
|
||||
def get_ss3_categories():
|
||||
return [
|
||||
"Coil (C)",
|
||||
"Alpha helix (H)",
|
||||
"Beta strand (E)",
|
||||
]
|
||||
return ["Coil (C)", "Alpha helix (H)", "Beta strand (E)"]
|
||||
|
||||
|
||||
def ss3_plot_index_to_letter(ss3_index: int) -> str:
|
||||
|
||||
@@ -124,9 +124,9 @@ def create_sequence_prompt_selector(
|
||||
r, g, b, a = hex_to_rgba_tuple(combined_color)
|
||||
a = 0.5 # Set alpha to 0.5
|
||||
combined_color = rgba_tuple_to_rgba_html_string((r, g, b, a))
|
||||
highlighted_line[
|
||||
i
|
||||
] = f'<span style="background-color:{combined_color}">{highlighted_line[i]}</span>'
|
||||
highlighted_line[i] = (
|
||||
f'<span style="background-color:{combined_color}">{highlighted_line[i]}</span>'
|
||||
)
|
||||
highlighted_lines.append("".join(highlighted_line))
|
||||
|
||||
return "<br>".join(highlighted_lines)
|
||||
|
||||
@@ -112,10 +112,9 @@ def create_structure_prompt_selector(
|
||||
).items():
|
||||
selected_ranges = tuple(selected_ranges) # Convert to hashable
|
||||
if selected_ranges in contact_map_selection_cache:
|
||||
(
|
||||
(x_start, x_end),
|
||||
(y_start, y_end),
|
||||
) = contact_map_selection_cache[selected_ranges]
|
||||
((x_start, x_end), (y_start, y_end)) = contact_map_selection_cache[
|
||||
selected_ranges
|
||||
]
|
||||
rect = Rectangle(
|
||||
(x_start - 0.5, max_y - y_end - 1.5),
|
||||
x_end - x_start + 1,
|
||||
|
||||
@@ -24,7 +24,5 @@ def get_forge_client(model_name: str) -> ESM3InferenceClient:
|
||||
"Forge API key not found. Please set the ESM_API_KEY environment variable."
|
||||
)
|
||||
return ESM3ForgeInferenceClient(
|
||||
model=model_name,
|
||||
url="https://forge.evolutionaryscale.ai",
|
||||
token=forge_token,
|
||||
model=model_name, url="https://forge.evolutionaryscale.ai", token=forge_token
|
||||
)
|
||||
|
||||
@@ -110,8 +110,7 @@ def draw_data_array(
|
||||
else:
|
||||
legend_patches = [
|
||||
patches.Patch(
|
||||
color=rgb_colors[category_to_index[category]],
|
||||
label=category,
|
||||
color=rgb_colors[category_to_index[category]], label=category
|
||||
)
|
||||
for category in categories
|
||||
]
|
||||
|
||||
@@ -26,9 +26,7 @@ def use_backend(backend):
|
||||
|
||||
|
||||
def draw_function_annotations(
|
||||
annotations: list[FunctionAnnotation],
|
||||
sequence_length: int,
|
||||
interpro_=InterPro(),
|
||||
annotations: list[FunctionAnnotation], sequence_length: int, interpro_=InterPro()
|
||||
) -> widgets.Image:
|
||||
cmap = colormaps["tab10"]
|
||||
colors = [cmap(i) for i in range(len(InterProEntryType))]
|
||||
@@ -63,9 +61,7 @@ def draw_function_annotations(
|
||||
with use_backend("agg"):
|
||||
fig, ax = plt.subplots()
|
||||
record = GraphicRecord(
|
||||
sequence=None,
|
||||
sequence_length=sequence_length,
|
||||
features=features,
|
||||
sequence=None, sequence_length=sequence_length, features=features
|
||||
)
|
||||
record.plot(ax=ax, plot_sequence=False)
|
||||
fig.savefig(buf, format="png", dpi=200, bbox_inches="tight")
|
||||
|
||||
@@ -19,8 +19,7 @@ def draw_protein_structure(
|
||||
|
||||
for start, end, color in highlighted_ranges:
|
||||
view.setStyle(
|
||||
{"resi": str(start) + "-" + str(end)},
|
||||
{"cartoon": {"color": color}},
|
||||
{"resi": str(start) + "-" + str(end)}, {"cartoon": {"color": color}}
|
||||
)
|
||||
|
||||
view.zoomTo()
|
||||
|
||||
@@ -8,18 +8,13 @@ PDB_INDEX = "PDB index"
|
||||
PDB_INDEX_SUFFIX = "[PDB Index]"
|
||||
|
||||
|
||||
def get_pdb_index_min_max(
|
||||
protein_chain: ProteinChain,
|
||||
) -> tuple[int, int]:
|
||||
def get_pdb_index_min_max(protein_chain: ProteinChain) -> tuple[int, int]:
|
||||
residue_index = protein_chain.residue_index
|
||||
valid_residue_index = residue_index[residue_index != -1]
|
||||
return min(valid_residue_index), max(valid_residue_index)
|
||||
|
||||
|
||||
def pdb_index_to_zero_index(
|
||||
residue_index: int,
|
||||
protein_chain: ProteinChain,
|
||||
) -> int:
|
||||
def pdb_index_to_zero_index(residue_index: int, protein_chain: ProteinChain) -> int:
|
||||
# Find the first position equal to residue_index
|
||||
pos = np.argwhere(residue_index == protein_chain.residue_index)
|
||||
if len(pos) == 0:
|
||||
@@ -27,16 +22,12 @@ def pdb_index_to_zero_index(
|
||||
return pos[0][0]
|
||||
|
||||
|
||||
def zero_index_to_pdb_index(
|
||||
zero_index: int,
|
||||
protein_chain: ProteinChain,
|
||||
) -> int:
|
||||
def zero_index_to_pdb_index(zero_index: int, protein_chain: ProteinChain) -> int:
|
||||
return protein_chain.residue_index[zero_index]
|
||||
|
||||
|
||||
def zero_range_to_pdb_range(
|
||||
zero_range: tuple[int, int],
|
||||
protein_chain: ProteinChain,
|
||||
zero_range: tuple[int, int], protein_chain: ProteinChain
|
||||
) -> tuple[int, int]:
|
||||
return (
|
||||
zero_index_to_pdb_index(zero_range[0], protein_chain),
|
||||
@@ -45,8 +36,7 @@ def zero_range_to_pdb_range(
|
||||
|
||||
|
||||
def pdb_range_to_zero_range(
|
||||
pdb_range: tuple[int, int],
|
||||
protein_chain: ProteinChain,
|
||||
pdb_range: tuple[int, int], protein_chain: ProteinChain
|
||||
) -> tuple[int, int]:
|
||||
return (
|
||||
pdb_index_to_zero_index(pdb_range[0], protein_chain),
|
||||
|
||||
@@ -196,9 +196,7 @@ class PromptManager:
|
||||
|
||||
def redraw(self, change=None):
|
||||
categories = ["Mask (-)"]
|
||||
color_map = {
|
||||
"Mask (-)": "white",
|
||||
}
|
||||
color_map = {"Mask (-)": "white"}
|
||||
data_array = [0] * self.prompt_length
|
||||
for prompt_str, *_ in self.prompts.items():
|
||||
color, _, _ = self.prompts[prompt_str]
|
||||
@@ -282,7 +280,7 @@ class PromptManager:
|
||||
value=(
|
||||
f'<div style="display: inline-block; width: 10px; height: 10px; background-color:{label_color}; margin-right: 5px;"></div>'
|
||||
f"{range_string}"
|
||||
),
|
||||
)
|
||||
)
|
||||
entry_label.tag = range_string # type: ignore
|
||||
entry_container = widgets.HBox([entry_button, entry_label])
|
||||
|
||||
@@ -9,11 +9,7 @@ from esm.widgets.utils.printing import wrapped_print
|
||||
|
||||
|
||||
class ProteinImporter:
|
||||
def __init__(
|
||||
self,
|
||||
max_proteins: int | None = None,
|
||||
autoload: bool = False,
|
||||
) -> None:
|
||||
def __init__(self, max_proteins: int | None = None, autoload: bool = False) -> None:
|
||||
self._protein_list: list[tuple[str, ProteinChain]] = []
|
||||
self._protein_workspace: dict[str, str] = {}
|
||||
self.max_proteins = max_proteins
|
||||
|
||||
@@ -55,9 +55,7 @@ def create_download_results_button(
|
||||
)
|
||||
|
||||
|
||||
def serialize_protein(
|
||||
protein: ESMProtein,
|
||||
) -> str:
|
||||
def serialize_protein(protein: ESMProtein) -> str:
|
||||
protein_dict = {
|
||||
"sequence": protein.sequence,
|
||||
"coordinates": protein.coordinates.tolist()
|
||||
|
||||
@@ -125,9 +125,7 @@ def create_esm3_generation_launcher(
|
||||
]
|
||||
)
|
||||
|
||||
generation_config_ui = widgets.VBox(
|
||||
[generation_config_settings_ui],
|
||||
)
|
||||
generation_config_ui = widgets.VBox([generation_config_settings_ui])
|
||||
|
||||
def on_track_change(change):
|
||||
if change["new"] == "function":
|
||||
|
||||
@@ -12,9 +12,7 @@ from esm.widgets.components.sequence_prompt_selector import (
|
||||
from esm.widgets.components.structure_prompt_selector import (
|
||||
create_structure_prompt_selector,
|
||||
)
|
||||
from esm.widgets.utils.prompting import (
|
||||
PromptManagerCollection,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManagerCollection
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
|
||||
@@ -2,20 +2,13 @@ from typing import Any, Literal
|
||||
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
)
|
||||
from esm.sdk.api import ESM3InferenceClient, ESMProtein
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.widgets.components.function_annotator import (
|
||||
create_function_annotator,
|
||||
)
|
||||
from esm.widgets.utils.prompting import (
|
||||
PromptManagerCollection,
|
||||
)
|
||||
from esm.widgets.utils.protein_import import (
|
||||
ProteinImporter,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManagerCollection
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
from esm.widgets.views.esm3_generation_launcher import (
|
||||
create_esm3_generation_launcher,
|
||||
)
|
||||
@@ -47,14 +40,9 @@ def create_generation_ui(
|
||||
protein_length_ui = widgets.VBox(
|
||||
[
|
||||
widgets.HTML(value="<h3>Specify Prompt Length:</h3>"),
|
||||
widgets.HBox(
|
||||
[
|
||||
protein_length_input,
|
||||
protein_length_confirm_button,
|
||||
]
|
||||
),
|
||||
widgets.HBox([protein_length_input, protein_length_confirm_button]),
|
||||
output,
|
||||
],
|
||||
]
|
||||
)
|
||||
loading_ui = widgets.HTML(value="<h3>Loading...</h3>")
|
||||
|
||||
@@ -117,10 +105,7 @@ def create_generation_ui(
|
||||
add_annotation_callback=prompt_manager_collection.add_function_annotation,
|
||||
delete_annotation_callback=prompt_manager_collection.delete_function_annotation,
|
||||
)
|
||||
function_annotator_ui.children = [
|
||||
function_annotator_title,
|
||||
function_annotator,
|
||||
]
|
||||
function_annotator_ui.children = [function_annotator_title, function_annotator]
|
||||
|
||||
if len(protein_importer.protein_list) == 0:
|
||||
prompt_ui.children = [
|
||||
@@ -139,10 +124,7 @@ def create_generation_ui(
|
||||
esm3_selector_ui = create_esm3_prompt_selector(
|
||||
prompt_manager_collection, protein_importer=protein_importer
|
||||
)
|
||||
selector_ui.children = [
|
||||
selector_title,
|
||||
esm3_selector_ui,
|
||||
]
|
||||
selector_ui.children = [selector_title, esm3_selector_ui]
|
||||
prompt_ui.children = [
|
||||
protein_importer_ui,
|
||||
protein_length_ui,
|
||||
@@ -184,10 +166,7 @@ def create_generation_ui(
|
||||
copy_to_prompt_callback=copy_to_prompt_callback,
|
||||
)
|
||||
generation_launcher_ui = widgets.VBox(
|
||||
[
|
||||
widgets.HTML(value="<h3>Generation Config:</h3>"),
|
||||
generation_launcher,
|
||||
]
|
||||
[widgets.HTML(value="<h3>Generation Config:</h3>"), generation_launcher]
|
||||
)
|
||||
|
||||
if len(protein_importer.protein_list) > 0:
|
||||
|
||||
@@ -10,9 +10,7 @@ from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import (
|
||||
ProteinImporter,
|
||||
)
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
def create_inverse_folding_ui(client: ESM3InferenceClient) -> widgets.Widget:
|
||||
|
||||
@@ -133,10 +133,7 @@ def create_login_ui(client_container: ClientInitContainer):
|
||||
start_msg_output,
|
||||
]
|
||||
elif change["new"] == "Local":
|
||||
model_selection_ui.children = [
|
||||
model_selection_header,
|
||||
local_model,
|
||||
]
|
||||
model_selection_ui.children = [model_selection_header, local_model]
|
||||
login_ui.children = [
|
||||
infobox,
|
||||
selection_ui,
|
||||
|
||||
@@ -10,9 +10,7 @@ from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import (
|
||||
ProteinImporter,
|
||||
)
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget:
|
||||
@@ -85,11 +83,7 @@ def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget:
|
||||
try:
|
||||
# Reset the output and results
|
||||
output.clear_output()
|
||||
prediction_ui.children = [
|
||||
input_ui,
|
||||
predict_button,
|
||||
output,
|
||||
]
|
||||
prediction_ui.children = [input_ui, predict_button, output]
|
||||
# Predict the protein's properties
|
||||
with output:
|
||||
protein = get_protein()
|
||||
@@ -159,19 +153,10 @@ def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget:
|
||||
wrapped_print(e)
|
||||
|
||||
predict_button.on_click(on_click_predict)
|
||||
protein_importer.entries_box.observe(
|
||||
on_new_protein,
|
||||
names="children",
|
||||
)
|
||||
protein_importer.entries_box.observe(on_new_protein, names="children")
|
||||
protein_importer.register_delete_callback(lambda: validate_predict(None))
|
||||
|
||||
sequence_input_ui.children[1].observe(
|
||||
on_new_sequence,
|
||||
names="value",
|
||||
)
|
||||
input_ui.observe(
|
||||
validate_predict,
|
||||
names="selected_index",
|
||||
)
|
||||
sequence_input_ui.children[1].observe(on_new_sequence, names="value")
|
||||
input_ui.observe(validate_predict, names="selected_index")
|
||||
|
||||
return prediction_ui
|
||||
|
||||
@@ -7,8 +7,9 @@ from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
GenerationConfig,
|
||||
InverseFoldingConfig,
|
||||
)
|
||||
from esm.sdk.forge import FoldForgeInferenceClient
|
||||
from esm.sdk.forge import SequenceStructureForgeInferenceClient
|
||||
|
||||
|
||||
def convert_none_to_nan(data):
|
||||
@@ -21,46 +22,53 @@ def convert_none_to_nan(data):
|
||||
return data
|
||||
|
||||
|
||||
def are_allclose_with_nan(A, B, rtol=1e-5, atol=1e-2):
|
||||
B = convert_none_to_nan(B)
|
||||
|
||||
A = np.array(A)
|
||||
B = np.array(B)
|
||||
|
||||
if A.shape != B.shape:
|
||||
raise ValueError("A and B must have the same shape")
|
||||
|
||||
nan_mask_A = np.isnan(A)
|
||||
nan_mask_B = np.isnan(B)
|
||||
|
||||
if not np.array_equal(nan_mask_A, nan_mask_B):
|
||||
return False
|
||||
|
||||
return np.allclose(A[~nan_mask_A], B[~nan_mask_B], rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def main(fold_client: FoldForgeInferenceClient, esm3_client: ESM3InferenceClient):
|
||||
# Folding
|
||||
def main(
|
||||
sequence_structure_client: SequenceStructureForgeInferenceClient,
|
||||
esm3_client: ESM3InferenceClient,
|
||||
):
|
||||
# Folding with esm3 client
|
||||
protein = get_sample_protein()
|
||||
sequence_length = len(protein.sequence) # type: ignore
|
||||
num_steps = int(sequence_length / 16)
|
||||
protein.coordinates = None
|
||||
protein.function_annotations = None
|
||||
protein.sasa = None
|
||||
assert protein.sequence is not None, "Protein sequence must be set to fold"
|
||||
# Folding with esm3 client
|
||||
folded_protein = cast(
|
||||
ESMProtein,
|
||||
esm3_client.generate(
|
||||
protein,
|
||||
GenerationConfig(
|
||||
track="structure", schedule="cosine", num_steps=num_steps, temperature=0
|
||||
),
|
||||
),
|
||||
)
|
||||
config = GenerationConfig(track="structure", num_steps=1, temperature=0)
|
||||
esm3_client_folded_protein = esm3_client.generate(protein, config)
|
||||
assert isinstance(
|
||||
esm3_client_folded_protein, ESMProtein
|
||||
), f"Using ESM3 client, ESMProtein was expected but got {protein}"
|
||||
# Folding with folding client
|
||||
coordinates = fold_client.fold(
|
||||
"esm3",
|
||||
protein.sequence, # type:ignore
|
||||
potential_sequence_of_concern=False,
|
||||
sequence_structure_client_folded_protein = sequence_structure_client.fold(
|
||||
protein.sequence, potential_sequence_of_concern=False
|
||||
)
|
||||
assert are_allclose_with_nan(folded_protein.coordinates, coordinates)
|
||||
assert isinstance(
|
||||
sequence_structure_client_folded_protein, ESMProtein
|
||||
), f"Using sequence_structure client, ESMProtein was expected but got {sequence_structure_client_folded_protein}"
|
||||
|
||||
# Inverse Folding with esm3 client
|
||||
protein = get_sample_protein()
|
||||
protein.sequence = None
|
||||
protein.sasa = None
|
||||
protein.function_annotations = None
|
||||
assert (
|
||||
protein.coordinates is not None
|
||||
), "Protein coordinates must be set to inverse fold"
|
||||
config = GenerationConfig("sequence", num_steps=1, temperature=0.7)
|
||||
esm3_client_inv_folded_protein = cast(
|
||||
ESMProtein, esm3_client.generate(protein, config)
|
||||
)
|
||||
assert isinstance(
|
||||
esm3_client_inv_folded_protein, ESMProtein
|
||||
), f"Using ESM3 client, ESMProtein was expected but got {protein}"
|
||||
# Inverse Folding with inverse folding client
|
||||
sequence_structure_client_inv_folded_protein = (
|
||||
sequence_structure_client.inverse_fold(
|
||||
protein.coordinates,
|
||||
config=InverseFoldingConfig(temperature=0.7),
|
||||
potential_sequence_of_concern=False,
|
||||
)
|
||||
)
|
||||
assert isinstance(
|
||||
sequence_structure_client_inv_folded_protein, ESMProtein
|
||||
), f"Using sequence_structure client, ESMProtein was expected but got {sequence_structure_client_inv_folded_protein}"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
import torch
|
||||
|
||||
from esm.models.esm3 import ESM3
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
@@ -38,8 +40,7 @@ def main(client: ESM3InferenceClient):
|
||||
protein.function_annotations = None
|
||||
protein = client.encode(protein)
|
||||
single_step_protein = client.forward_and_sample(
|
||||
protein,
|
||||
SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2)),
|
||||
protein, SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2))
|
||||
)
|
||||
single_step_protein.protein_tensor.sequence = protein.sequence
|
||||
single_step_protein = client.decode(single_step_protein.protein_tensor)
|
||||
@@ -52,8 +53,7 @@ def main(client: ESM3InferenceClient):
|
||||
)
|
||||
protein = ESMProtein(sequence=prompt)
|
||||
protein = client.generate(
|
||||
protein,
|
||||
GenerationConfig(track="sequence", num_steps=8, temperature=0.7),
|
||||
protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7)
|
||||
)
|
||||
assert isinstance(protein, ESMProtein), f"ESMProtein was expected but got {protein}"
|
||||
|
||||
@@ -189,5 +189,6 @@ def main(client: ESM3InferenceClient):
|
||||
assert isinstance(p, ESMProtein), f"ESMProtein was expected but got {p}"
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(ESM3.from_pretrained("esm3_sm_open_v1"))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.0.7post1"
|
||||
version = "3.0.8"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -53,9 +53,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from esm.widgets.utils.types import ClientInitContainer\n",
|
||||
"from esm.widgets.views.inverse_folding import (\n",
|
||||
" create_inverse_folding_ui,\n",
|
||||
")\n",
|
||||
"from esm.widgets.views.inverse_folding import create_inverse_folding_ui\n",
|
||||
"from esm.widgets.views.login import create_login_ui"
|
||||
]
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user