Sync from internal

This commit is contained in:
Zeming Lin
2024-11-25 22:38:31 +00:00
parent 39a3a6cb1e
commit b49c708a19
62 changed files with 2326 additions and 2501 deletions

View File

@@ -1 +1,2 @@
__version__ = "3.0.7post1"
__version__ = "3.0.8"

View File

@@ -10,11 +10,7 @@ from esm.layers.rotary import RotaryEmbedding
class MultiHeadAttention(nn.Module):
def __init__(
self,
d_model: int,
n_heads: int,
bias: bool = False,
qk_layernorm: bool = True,
self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
):
super().__init__()

View File

@@ -78,11 +78,7 @@ class GeometricReasoningOriginalImpl(nn.Module):
affine.rot[..., None]
.apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
.split(
[
self.v_heads,
self.v_heads,
self.v_heads * self.num_vector_messages,
],
[self.v_heads, self.v_heads, self.v_heads * self.num_vector_messages],
dim=-2,
)
)

View File

@@ -2,9 +2,7 @@ import torch.nn as nn
def RegressionHead(
d_model: int,
output_dim: int,
hidden_dim: int | None = None,
d_model: int, output_dim: int, hidden_dim: int | None = None
) -> nn.Module:
"""Single-hidden layer MLP for supervised output.

View File

@@ -1,9 +1,7 @@
import torch
import torch.nn as nn
from esm.utils.constants.physics import (
BB_COORDINATES,
)
from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,

View File

@@ -29,9 +29,7 @@ from esm.sdk.api import (
ProteinType,
SamplingConfig,
)
from esm.tokenization import (
TokenizerCollectionProtocol,
)
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import (
@@ -173,11 +171,7 @@ class OutputHeads(nn.Module):
secondary_structure_logits = self.ss8_head(x)
sasa_logits = self.sasa_head(x)
function_logits = self.function_head(x)
function_logits = einops.rearrange(
function_logits,
"... (k v) -> ... k v",
k=8,
)
function_logits = einops.rearrange(function_logits, "... (k v) -> ... k v", k=8)
residue_logits = self.residue_head(x)
@@ -217,11 +211,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
super().__init__()
self.encoder = EncodeInputs(d_model)
self.transformer = TransformerStack(
d_model,
n_heads,
v_heads,
n_layers,
mask_and_zero_frameless=True,
d_model, n_heads, v_heads, n_layers, mask_and_zero_frameless=True
)
self.output_heads = OutputHeads(d_model)
@@ -237,9 +227,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
@classmethod
def from_pretrained(
cls,
model_name: str = ESM3_OPEN_SMALL,
device: torch.device | None = None,
cls, model_name: str = ESM3_OPEN_SMALL, device: torch.device | None = None
) -> ESM3:
from esm.pretrained import load_local_model
@@ -489,15 +477,14 @@ class ESM3(nn.Module, ESM3InferenceClient):
reference_sequence = encoding.get_default_sequence(sequence_length - 2)
else:
reference_sequence = input.sequence
(
function_tokens,
residue_annotation_tokens,
) = encoding.tokenize_function_annotations(
input.function_annotations,
reference_sequence=reference_sequence,
function_tokenizer=self.tokenizers.function,
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
add_special_tokens=True,
(function_tokens, residue_annotation_tokens) = (
encoding.tokenize_function_annotations(
input.function_annotations,
reference_sequence=reference_sequence,
function_tokenizer=self.tokenizers.function,
residue_annotation_tokenizer=self.tokenizers.residue_annotations,
add_special_tokens=True,
)
)
return ESMProteinTensor(
@@ -510,10 +497,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
coordinates=coordinates,
).to(next(self.parameters()).device)
def decode(
self,
input: ESMProteinTensor,
) -> ESMProtein:
def decode(self, input: ESMProteinTensor) -> ESMProtein:
return decode_protein_tensor(
input=input,
tokenizers=self.tokenizers,
@@ -613,10 +597,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
logits_output: LogitsOutput = _batch_forward(self, batched_protein)
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
batched_protein,
logits_output,
sampling_config,
self.tokenizers,
batched_protein, logits_output, sampling_config, self.tokenizers
)
# There is only 1 prompt to sample for.

View File

@@ -167,8 +167,7 @@ class FunctionTokenDecoder(nn.Module):
# Apply depth-position offset to use distinct vocabs. See __init__ for
# explaination.
vocab_offsets = self.config.function_token_vocab_size * torch.arange(
self.config.function_token_depth,
device=token_ids.device,
self.config.function_token_depth, device=token_ids.device
)
inputs = token_ids + vocab_offsets[None, :]
@@ -251,8 +250,7 @@ class FunctionTokenDecoder(nn.Module):
annotations.append(annotation)
annotations = merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
annotations, merge_gap_max=annotation_gap_merge_max
)
# Drop very small annotations.

View File

@@ -87,10 +87,7 @@ class PairwisePredictionHead(nn.Module):
prod = q[:, None, :, :] * k[:, :, None, :]
diff = q[:, None, :, :] - k[:, :, None, :]
x_2d = [
prod,
diff,
]
x_2d = [prod, diff]
if pairwise is not None:
x_2d.append(pairwise)
x = torch.cat(x_2d, dim=-1)
@@ -289,11 +286,7 @@ class StructureTokenEncoder(nn.Module):
with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore
ca = coords[..., 1, :]
edges, edge_mask = knn_graph(
ca,
coord_mask,
padding_mask,
sequence_id,
no_knn=knn,
ca, coord_mask, padding_mask, sequence_id, no_knn=knn
)
return edges, edge_mask
@@ -333,12 +326,7 @@ class StructureTokenEncoder(nn.Module):
class StructureTokenDecoder(nn.Module):
def __init__(
self,
d_model,
n_heads,
n_layers,
):
def __init__(self, d_model, n_heads, n_layers):
super().__init__()
self.decoder_channels = d_model

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC
from typing import Sequence
from typing import List, Sequence
import attr
import torch
@@ -19,14 +19,10 @@ from esm.utils.misc import (
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import ProteinComplex
from esm.utils.types import (
FunctionAnnotation,
PathOrBuffer,
)
from esm.utils.types import FunctionAnnotation, PathOrBuffer
class ProteinType(ABC):
...
class ProteinType(ABC): ...
## Basic Types
@@ -184,6 +180,9 @@ class ESMProteinTensor(ProteinType):
# Such sequences may not go through standard safety filter for approved users.
# Reach out if interested in using this.
potential_sequence_of_concern: bool = False
# Control vectors are vectors added to each layer of the model to nudge hidden states to the desired direction.
# len(control_vectors) == number of blocks in the model. Each vector in the list have the shape of (batch size, sequence length, hidden dim)
# so it can be added to the corresponding layer in the model
def _detect_attribute(self, func, msg):
mapped = {
@@ -260,20 +259,40 @@ class ESMProteinError(Exception, ProteinType):
class GenerationConfig:
track: str = ""
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
# Controls the number of tokens to unmask during each round of iterative generation.
schedule: str = attr.field(
validator=attr.validators.in_(["cosine", "linear"]), default="cosine"
)
# Controls which tokens to unmask during each round of iterative generation.
# "random" will unmask a correct number of tokens randomly.
# "entropy" will unmask the tokens with the lowest logit entropy first.
strategy: str = attr.field(
validator=attr.validators.in_(["random", "entropy"]), default="entropy"
)
# Set this to a higher value for better generation results.
# Note that this needs to be less than or equal to the sequence length.
num_steps: int = 1
temperature: float = 1.0
temperature_annealing: bool = False
top_p: float = 1.0
condition_on_coordinates_only: bool = True
def use_entropy_based_unmasking_strategy(self):
"""Use entropy based unmasking strategy during generation."""
self.schedule = "cosine"
self.strategy = "entropy"
self.temperature_annealing = False
def use_generative_unmasking_strategy(self):
"""Use an unmasking strategy that produces more variety of generations."""
self.schedule = "cosine"
self.strategy = "random"
self.temperature_annealing = True
@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
num_steps: int = 1
temperature: float = 1.0
@@ -370,9 +389,7 @@ class ESM3InferenceClient(ABC):
raise NotImplementedError
def batch_generate(
self,
inputs: Sequence[ProteinType],
configs: Sequence[GenerationConfig],
self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
) -> Sequence[ProteinType]:
# Same as generate(...), but generates a batch of proteins at once.
raise NotImplementedError

View File

@@ -5,12 +5,7 @@ from urllib.parse import urljoin
import requests
import torch
from tenacity import (
retry,
retry_if_result,
stop_after_attempt,
wait_exponential,
)
from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential
from esm.sdk.api import (
ESM3InferenceClient,
@@ -20,6 +15,7 @@ from esm.sdk.api import (
ForwardAndSampleOutput,
ForwardTrackData,
GenerationConfig,
InverseFoldingConfig,
LogitsConfig,
LogitsOutput,
ProteinType,
@@ -55,7 +51,15 @@ def log_retry_attempt(retry_state):
)
class FoldForgeInferenceClient:
def _validate_protein_tensor_input(input):
if not isinstance(input, ESMProteinTensor):
raise ValueError(
"Input must be an ESMProteinTensor instance. "
"Use encode() API to encode an ESMProtein into ESMProteinTensor."
)
class SequenceStructureForgeInferenceClient:
def __init__(
self,
url: str = "https://forge.evolutionaryscale.ai",
@@ -73,31 +77,51 @@ class FoldForgeInferenceClient:
def fold(
self,
model_name: str,
sequence: str,
potential_sequence_of_concern: bool,
) -> torch.Tensor | ESMProteinError:
request = {
"model": model_name,
"sequence": sequence,
}
model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
request = {"sequence": sequence}
if model_name is not None:
request["model"] = model_name
try:
data = self._post(
"fold",
request,
potential_sequence_of_concern,
)
data = self._post("fold", request, potential_sequence_of_concern)
except ESMProteinError as e:
return e
return data["coordinates"]
return ESMProtein(
coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True)
)
def inverse_fold(
self,
coordinates: torch.Tensor,
config: InverseFoldingConfig,
potential_sequence_of_concern: bool,
model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
inverse_folding_config = {
"invalid_ids": config.invalid_ids,
"temperature": config.temperature,
}
request = {
"coordinates": maybe_list(coordinates, convert_nan_to_none=True),
"inverse_folding_config": inverse_folding_config,
}
if model_name is not None:
request["model"] = model_name
try:
data = self._post("inverse_fold", request, potential_sequence_of_concern)
except ESMProteinError as e:
return e
return ESMProtein(sequence=data["sequence"])
def _post(self, endpoint, request, potential_sequence_of_concern):
request["potential_sequence_of_concern"] = potential_sequence_of_concern
model_name_url = request["model"] if request["model"] != "esm3" else "api"
response = requests.post(
urljoin(self.url, f"/{model_name_url}/v1/{endpoint}"),
urljoin(self.url, f"/api/v1/{endpoint}"),
json=request,
headers=self.headers,
timeout=self.request_timeout,
@@ -115,6 +139,11 @@ class FoldForgeInferenceClient:
if "outputs" not in data and "data" in data:
data = data["data"]
# Print warning message if there is any.
if "warning_messages" in data and data["warning_messages"] is not None:
for msg in data["warning_messages"]:
print("\033[31m", msg, "\033[0m")
return data
@@ -174,18 +203,13 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
output = self.__generate_protein_tensor(input, config)
else:
return ESMProteinError(
error_code=500,
error_msg=f"Unknown input type {type(input)}",
error_code=500, error_msg=f"Unknown input type {type(input)}"
)
if (
isinstance(output, ESMProtein)
and isinstance(input, ESMProtein)
and config.track
not in [
"function",
"residue_annotations",
]
and config.track not in ["function", "residue_annotations"]
):
# Function and residue annotation encoding/decoding is lossy
# There is no guarantee that decoding encoded tokens will yield the same input
@@ -218,9 +242,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
return [_capture_exception(r) for r in results]
def __generate_protein(
self,
input: ESMProtein,
config: GenerationConfig,
self, input: ESMProtein, config: GenerationConfig
) -> ESMProtein | ESMProteinError:
req = {}
req["sequence"] = input.sequence
@@ -261,9 +283,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
)
def __generate_protein_tensor(
self,
input: ESMProteinTensor,
config: GenerationConfig,
self, input: ESMProteinTensor, config: GenerationConfig
) -> ESMProteinTensor | ESMProteinError:
req = {}
req["sequence"] = maybe_list(input.sequence)
@@ -316,6 +336,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
def forward_and_sample(
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput | ESMProteinError:
_validate_protein_tensor_input(input)
validate_sampling_config(sampling_configuration, on_invalid="raise")
req = {}
@@ -441,10 +462,9 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
)
@retry_decorator
def decode(
self,
input: ESMProteinTensor,
) -> ESMProtein | ESMProteinError:
def decode(self, input: ESMProteinTensor) -> ESMProtein | ESMProteinError:
_validate_protein_tensor_input(input)
tokens = {}
tokens["sequence"] = maybe_list(input.sequence)
tokens["structure"] = maybe_list(input.structure)
@@ -454,10 +474,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
tokens["residue_annotation"] = maybe_list(input.residue_annotations)
tokens["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)
request = {
"model": self.model,
"inputs": tokens,
}
request = {"model": self.model, "inputs": tokens}
try:
data = self._post("decode", request, input.potential_sequence_of_concern)
@@ -482,6 +499,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
def logits(
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
) -> LogitsOutput | ESMProteinError:
_validate_protein_tensor_input(input)
# Note: using raw model forwards is discouraged because of the byte size
# of the logits.
# Please use forward_and_sample instead.
@@ -504,11 +523,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient):
"return_embeddings": config.return_embeddings,
}
request = {
"model": self.model,
"inputs": req,
"logits_config": logits_config,
}
request = {"model": self.model, "inputs": req, "logits_config": logits_config}
try:
data = self._post("logits", request, input.potential_sequence_of_concern)

View File

@@ -2,7 +2,61 @@ import json
import boto3
from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.forge import (
ESM3ForgeInferenceClient,
SequenceStructureForgeInferenceClient,
)
class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
def __init__(self, endpoint_name: str):
"""SequenceStructure (folding and inverse folding) client that talks to a SageMaker endpoint.
Args:
endpoint_name: Name of the SageMaker endpoint.
"""
# Dummy URL and token to make SequenceStructureForgeInferenceClient happy.
super().__init__(url="", token="dummy")
self._endpoint_name = endpoint_name
self._client = boto3.client(service_name="sagemaker-runtime")
def _post(self, endpoint, request, potential_sequence_of_concern):
request["potential_sequence_of_concern"] = potential_sequence_of_concern
request["model"] = request.get("model", None)
invocations_request = {
# Duplicate these fields at the top level to make Forge requests consistent.
"model": request["model"],
"request_id": "", # Forge specific field.
"user_id": "", # Forge specific field.
# Invocation data bits.
"api_ver": "v1", # Must be v1 right now.
"endpoint": endpoint,
# Wrapped request.
endpoint: request,
}
try:
response = self._client.invoke_endpoint(
EndpointName=self._endpoint_name,
ContentType="application/json",
Body=json.dumps(invocations_request),
)
except Exception as e:
raise RuntimeError(f"Failure in {endpoint}: {e}") from e
data = json.loads(response["Body"].read().decode())
# Response must match request.
assert (
data["endpoint"] == endpoint
), f"Response endpoint is {data['endpoint']} but request is {endpoint}"
# Get the actual responses under the endpoint key.
data = data[endpoint]
return data
class ESM3SageMakerClient(ESM3ForgeInferenceClient):

View File

@@ -120,8 +120,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
def _tfidf(self) -> tfidf.TFIDFModel:
"""Creates TF-IDF model for encoding function keywords."""
return tfidf.TFIDFModel(
vocabulary_path=self.keyword_vocabulary_path,
idf_path=self.keyword_idf_path,
vocabulary_path=self.keyword_vocabulary_path, idf_path=self.keyword_idf_path
)
@cached_property
@@ -205,9 +204,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
return tokens
def _function_text_hash(
self,
labels: Collection[str],
keyword_mask: np.ndarray | None = None,
self, labels: Collection[str], keyword_mask: np.ndarray | None = None
) -> np.ndarray | None:
"""Applies a locality sensitive hash (LSH) to function text.
@@ -295,9 +292,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
raise ValueError(f"Unknown token: {token}")
def batch_encode(
self,
token_batch: list[list[str]],
add_special_tokens: bool = True,
self, token_batch: list[list[str]], add_special_tokens: bool = True
) -> torch.Tensor:
"""Encodes batch of function tokens.
@@ -312,8 +307,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
for tokens in token_batch
]
return stack_variable_length_tensors(
encoded,
constant_value=self.vocab_to_index["<pad>"],
encoded, constant_value=self.vocab_to_index["<pad>"]
)
def decode(self, encoded: torch.Tensor):

View File

@@ -13,11 +13,7 @@ Sample = dict[str, Any]
class ResidueAnnotationsTokenizer(EsmTokenizerBase):
def __init__(
self,
csv_path: str | None = None,
max_annotations: int = 16,
):
def __init__(self, csv_path: str | None = None, max_annotations: int = 16):
if csv_path is None:
csv_path = str(C.data_root() / C.RESID_CSV)
self.csv_path = csv_path

View File

@@ -31,7 +31,7 @@ class SASADiscretizingTokenizer(EsmTokenizerBase):
return self.special_tokens + range_tokens
@cached_property
def midpoints(self) -> list[float]:
def midpoints_tensor(self) -> torch.Tensor:
"""Midpoints of the SASA token ranges."""
boundaries = [0] + self._boundaries + [self._boundaries[-1] * 2]
midpoint_tokens = [
@@ -39,7 +39,11 @@ class SASADiscretizingTokenizer(EsmTokenizerBase):
for low, high in zip(boundaries[:-1], boundaries[1:])
]
midpoint_tokens = [float("nan"), float("nan"), float("nan")] + midpoint_tokens
return midpoint_tokens
return torch.Tensor(midpoint_tokens)
def midpoints(self) -> list[float]:
"""Midpoints of the SASA token ranges."""
return self.midpoints_tensor.tolist()
@cached_property
def vocab_to_index(self) -> dict[str, int]:
@@ -86,7 +90,11 @@ class SASADiscretizingTokenizer(EsmTokenizerBase):
def decode_float(self, encoded: torch.Tensor) -> list[float]:
"""Decodes SASA token ids into float values."""
return [self.midpoints[token_id] for token_id in encoded]
decoded = self.midpoints_tensor[encoded.cpu()]
nan_mask = torch.isnan(decoded)
np_arr = decoded.numpy()
np_arr[nan_mask.numpy()] = None
return np_arr.tolist()
def decode(self, encoded: torch.Tensor) -> str:
"""Decodes SASA token ids."""

View File

@@ -40,9 +40,7 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast, EsmTokenizerBase):
self.cb_token = chain_break_token
additional_special_tokens = [chain_break_token]
tokenizer.add_special_tokens(
special_tokens,
)
tokenizer.add_special_tokens(special_tokens)
# This is where we configure the automatic addition of special tokens when we call
# tokenizer(text, add_special_tokens=True). Note that you can also configure how two

View File

@@ -3,56 +3,42 @@ from typing import Protocol, runtime_checkable
@runtime_checkable
class EsmTokenizerBase(Protocol):
def encode(self, *args, **kwargs):
...
def encode(self, *args, **kwargs): ...
def decode(self, *args, **kwargs):
...
def decode(self, *args, **kwargs): ...
@property
def mask_token(self) -> str:
...
def mask_token(self) -> str: ...
@property
def mask_token_id(self) -> int:
...
def mask_token_id(self) -> int: ...
@property
def bos_token(self) -> str:
...
def bos_token(self) -> str: ...
@property
def bos_token_id(self) -> int:
...
def bos_token_id(self) -> int: ...
@property
def eos_token(self) -> str:
...
def eos_token(self) -> str: ...
@property
def eos_token_id(self) -> int:
...
def eos_token_id(self) -> int: ...
@property
def pad_token(self) -> str:
...
def pad_token(self) -> str: ...
@property
def pad_token_id(self) -> int:
...
def pad_token_id(self) -> int: ...
@property
def chain_break_token(self) -> str:
...
def chain_break_token(self) -> str: ...
@property
def chain_break_token_id(self) -> int:
...
def chain_break_token_id(self) -> int: ...
@property
def all_token_ids(self):
...
def all_token_ids(self): ...
@property
def special_token_ids(self):
...
def special_token_ids(self): ...

View File

@@ -112,9 +112,7 @@ INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
LSH_TABLE_PATHS = {
"8bit": "data/hyperplanes_8bit_58641.npz",
}
LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
KEYWORDS_VOCABULARY = (
IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt"

View File

@@ -1,4 +1,5 @@
import warnings
from typing import cast
import attr
import torch
@@ -31,7 +32,7 @@ from esm.utils.function.encode_decode import (
decode_function_tokens,
decode_residue_annotation_tokens,
)
from esm.utils.misc import list_nan_to_none
from esm.utils.misc import maybe_list
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation
@@ -130,15 +131,10 @@ def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
def decode_sequence(
sequence_tokens: torch.Tensor,
sequence_tokenizer: EsmSequenceTokenizer,
**kwargs,
sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs
) -> str:
_bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
sequence = sequence_tokenizer.decode(
sequence_tokens,
**kwargs,
)
sequence = sequence_tokenizer.decode(sequence_tokens, **kwargs)
sequence = sequence.replace(" ", "")
sequence = sequence.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT)
sequence = sequence.replace(sequence_tokenizer.cls_token, "")
@@ -185,20 +181,16 @@ def decode_structure(
def decode_secondary_structure(
secondary_structure_tokens: torch.Tensor,
ss_tokenizer: SecondaryStructureTokenizer,
secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer
) -> str:
_bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
secondary_structure_tokens = secondary_structure_tokens[1:-1]
secondary_structure = ss_tokenizer.decode(
secondary_structure_tokens,
)
secondary_structure = ss_tokenizer.decode(secondary_structure_tokens)
return secondary_structure
def decode_sasa(
sasa_tokens: torch.Tensor,
sasa_tokenizer: SASADiscretizingTokenizer,
sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer
) -> list[float]:
if sasa_tokens[0] != 0:
raise ValueError("SASA does not start with 0 corresponding to BOS token")
@@ -213,12 +205,13 @@ def decode_sasa(
torch.long,
]:
# Decode if int
# handles turning NaN's into None's
sasa = sasa_tokenizer.decode_float(sasa_tokens)
else:
# If already float, just convert to list
sasa = sasa_tokens.tolist()
sasa = cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True))
return list_nan_to_none(sasa)
return sasa
def decode_function_annotations(

View File

@@ -97,9 +97,7 @@ def tokenize_structure(
# Add space for BOS and EOS tokens
if add_special_tokens:
coordinates = F.pad(
coordinates,
(0, 0, 0, 0, left_pad, right_pad),
value=torch.inf,
coordinates, (0, 0, 0, 0, left_pad, right_pad), value=torch.inf
)
plddt = F.pad(plddt, (left_pad, right_pad), value=0)
structure_tokens = F.pad(
@@ -171,8 +169,7 @@ def tokenize_function_annotations(
# Tokenized Defaults
def get_default_sequence_tokens(
sequence_length: int,
sequence_tokenizer: EsmSequenceTokenizer,
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
) -> torch.Tensor:
assert sequence_tokenizer.mask_token_id is not None
assert sequence_tokenizer.bos_token_id is not None
@@ -191,10 +188,7 @@ def get_default_structure_tokens(
sequence_length: int, structure_tokenizer: StructureTokenizer
) -> torch.Tensor:
structure_tokens = (
torch.ones(
(sequence_length + 2,),
dtype=torch.int64,
)
torch.ones((sequence_length + 2,), dtype=torch.int64)
* structure_tokenizer.mask_token_id
)
# Always include BOS and EOS tokens
@@ -241,10 +235,7 @@ def get_default_residue_annotation_tokens(
sequence_length: int, residue_annotation_tokenizer: ResidueAnnotationsTokenizer
) -> torch.Tensor:
residue_annotation_tokens = (
torch.ones(
(sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS),
dtype=torch.int64,
)
torch.ones((sequence_length + 2, C.MAX_RESIDUE_ANNOTATIONS), dtype=torch.int64)
* residue_annotation_tokenizer.pad_token_id
)
# Always include BOS and EOS tokens

View File

@@ -59,8 +59,7 @@ def encode_function_annotations(
# Convert function token FunctionAnnotations -> Tensor
function_tokens = function_tokens_tokenizer.tokenize(
annotations=ft_annotations,
seqlen=len(sequence),
annotations=ft_annotations, seqlen=len(sequence)
)
function_token_ids = function_tokens_tokenizer.encode(
function_tokens, add_special_tokens=add_special_tokens
@@ -175,10 +174,7 @@ def decode_residue_annotation_tokens(
annotation = FunctionAnnotation(label=label, start=loc, end=loc)
annotations.append(annotation)
annotations = merge_annotations(
annotations,
merge_gap_max=annotation_gap_merge_max,
)
annotations = merge_annotations(annotations, merge_gap_max=annotation_gap_merge_max)
# Drop very small annotations.
if annotation_min_length is not None:

View File

@@ -127,11 +127,7 @@ class InterPro:
col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
)
df.rename(
columns={
"ENTRY_AC": "id",
"ENTRY_TYPE": "type",
"ENTRY_NAME": "name",
},
columns={"ENTRY_AC": "id", "ENTRY_TYPE": "type", "ENTRY_NAME": "name"},
inplace=True,
)
df["type"] = df.type.str.upper().apply(

View File

@@ -50,8 +50,7 @@ class TFIDFModel:
values /= np.linalg.norm(values)
return sparse.csr_matrix(
(values, (np.zeros_like(indices), indices)),
shape=(1, len(self.vocabulary)),
(values, (np.zeros_like(indices), indices)), shape=(1, len(self.vocabulary))
)
def decode(self, vec: sparse.csr_matrix) -> list[str]:

View File

@@ -128,9 +128,7 @@ def iterative_sampling_raw(
def _make_masked_inputs(
track: str,
sequence_length: int,
tokenizers: TokenizerCollectionProtocol,
track: str, sequence_length: int, tokenizers: TokenizerCollectionProtocol
):
get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
@@ -190,8 +188,7 @@ def _stack_protein_tensors(
o,
fn,
stack_variable_length_tensors(
sequences=tensors,
constant_value=mask_token_id,
sequences=tensors, constant_value=mask_token_id
),
)
@@ -240,6 +237,11 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
shape = tokens.shape
B, L = shape[0], shape[1]
# TODO: figure out why we want this function to work with
# _BatchedESMProteinTensor in the first place. Logics below
# don't really work for batched tensors.
assert B == 1
sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device)
sampling_mask[:, 0] = False # BOS
# EOS and all padding tokens.
@@ -248,9 +250,7 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
).to(device)
is_mask = _get_masked_positions(
track_to_sample,
tokens,
getattr(tokenizers, track_to_sample).mask_token_id,
track_to_sample, tokens, getattr(tokenizers, track_to_sample).mask_token_id
)
if not is_mask.any().item():
raise ValueError(f"Cannot sample {config.track} when input has no masks.")
@@ -273,27 +273,36 @@ def _get_iterative_sampling_mask_for_prompt_and_step(
).int()
num_to_sample = still_masked - num_tokens_masked_after_this_step
track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to(
device
) # (B, L) or (B, L, D)
if config.strategy == "entropy":
track_entropy: torch.Tensor = getattr(entropy, track_to_sample).to(
device
) # (B, L) or (B, L, D)
if track_to_sample == "function":
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
if track_to_sample == "function":
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
track_entropy = track_entropy.masked_fill(
~sampling_mask, torch.finfo(track_entropy.dtype).max
)
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
1, indices, True
)
where_to_sample = sampling_mask & is_top_k
track_entropy = track_entropy.masked_fill(
~sampling_mask, torch.finfo(track_entropy.dtype).max
)
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
1, indices, True
)
where_to_sample = sampling_mask & is_top_k
elif config.strategy == "random":
# Skip B since we know there is only 1 prompt here.
_, masked_indices = sampling_mask.nonzero(as_tuple=True)
# Random shuffle the masked indices then select the first num_to_sample.
rnd_indices = masked_indices[torch.randperm(len(masked_indices))][
:num_to_sample
]
rnd_mask = torch.zeros_like(sampling_mask)
rnd_mask[:, rnd_indices] = True
where_to_sample = sampling_mask & rnd_mask
if track_to_sample == "function":
where_to_sample = where_to_sample.unsqueeze(-1).expand(
B,
L,
tokenizers.function.depth,
B, L, tokenizers.function.depth
) # (B, L) -> (B, L, D)
return where_to_sample
@@ -316,6 +325,11 @@ def _get_non_special_tokens(
return int(torch.sum(mask).item())
def _get_annealed_temperature(step: int, num_steps: int, initial_temperature: float):
step_ratio = step / max(1, (num_steps - 1))
return max(initial_temperature - step_ratio, 0.001) ** 2
def iterative_sampling_tokens(
client: ESM3InferenceClient,
input_tokens: list[ESMProteinTensor],
@@ -345,9 +359,7 @@ def iterative_sampling_tokens(
num_sampling_steps = _get_non_special_tokens(protein, tokenizers)
else:
masked = _get_masked_positions(
track,
getattr(protein, track),
getattr(tokenizers, track).mask_token_id,
track, getattr(protein, track), getattr(tokenizers, track).mask_token_id
)
num_sampling_steps = torch.sum(masked).item()
@@ -365,10 +377,7 @@ def iterative_sampling_tokens(
# Now stack the list to make a single batched ESMProteinTensor.
batched_tokens = _stack_protein_tensors(
sampled_tokens,
sequence_lengths,
tokenizers,
devices.pop(),
sampled_tokens, sequence_lengths, tokenizers, devices.pop()
)
# Remember sampled prompts that has somehow errored out.
@@ -418,9 +427,18 @@ def iterative_sampling_tokens(
len(per_prompt_cur_sampled),
)
# Handle temperature annealing, since _sample_per_prompt() doesn't have
# the concept of decoding steps.
if config.temperature_annealing:
temperature = _get_annealed_temperature(
t, config.num_steps, config.temperature
)
else:
temperature = config.temperature
track_sample_config = SamplingTrackConfig()
track_sample_config.invalid_ids = config.invalid_ids
track_sample_config.temperature = config.temperature
track_sample_config.temperature = temperature
track_sample_config.top_p = config.top_p
sampling_config = SamplingConfig(**{config.track: track_sample_config}) # type: ignore
@@ -486,7 +504,7 @@ def iterative_sampling_tokens(
setattr(outputs, "coordinates", getattr(inputs, "coordinates"))
# Maybe restore all the other fields.
for f in attr.fields(SamplingConfig):
if "embedding" in f.name:
if "embedding" in f.name or f.name == "return_hidden_states":
continue
if f.name != config.track:
setattr(outputs, f.name, getattr(inputs, f.name))
@@ -494,10 +512,7 @@ def iterative_sampling_tokens(
return output_tokens
def _batch_forward(
client: ESM3InferenceClient,
protein: _BatchedESMProteinTensor,
):
def _batch_forward(client: ESM3InferenceClient, protein: _BatchedESMProteinTensor):
# Forward pass
return client.logits(
protein,

View File

@@ -97,9 +97,7 @@ def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_
@pytest.mark.gpu
def test_encode_chainbreak_token(esm3_remote_inference_client):
protein = esm3_remote_inference_client.encode(
ESMProtein(sequence="MSTNP|KPQKK"),
)
protein = esm3_remote_inference_client.encode(ESMProtein(sequence="MSTNP|KPQKK"))
assert isinstance(protein, ESMProteinTensor)
assert protein.sequence is not None
assert (

View File

@@ -1,4 +1,3 @@
import math
import os
from collections import defaultdict
from typing import ContextManager, Sequence, TypeVar
@@ -226,8 +225,7 @@ def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[
def merge_annotations(
annotations: list[FunctionAnnotation],
merge_gap_max: int | None = None,
annotations: list[FunctionAnnotation], merge_gap_max: int | None = None
) -> list[FunctionAnnotation]:
"""Merges annotations into non-overlapping segments.
@@ -256,42 +254,24 @@ def merge_annotations(
return merged
def list_nan_to_none(l: list) -> list:
if l is None:
return None # type: ignore
elif isinstance(l, float):
return None if math.isnan(l) else l # type: ignore
elif isinstance(l, list):
return [list_nan_to_none(x) for x in l]
else:
# Don't go into other structures.
return l
def list_none_to_nan(l: list) -> list:
if l is None:
return math.nan # type: ignore
elif isinstance(l, list):
return [list_none_to_nan(x) for x in l]
else:
return l
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
if x is None:
return None
if convert_none_to_nan:
x = list_none_to_nan(x)
x = np.array(x, copy=False, dtype=np.float32)
x = np.where(x is None, np.nan, x)
return torch.tensor(x)
def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
if x is None:
return None
x = x.tolist()
if convert_nan_to_none:
x = list_nan_to_none(x)
return x
if not convert_nan_to_none:
return x.tolist()
nan_mask = torch.isnan(x)
np_arr = x.cpu().numpy().astype(object)
np_arr[nan_mask.cpu().numpy()] = None
return np_arr.tolist()
def huggingfacehub_login():

View File

@@ -1,6 +1,5 @@
"""Tests for misc.py"""
from esm.utils.misc import merge_annotations
from esm.utils.types import FunctionAnnotation

View File

@@ -22,16 +22,30 @@ from esm.utils.constants.esm3 import (
SASA_DISCRETIZATION_BOUNDARIES,
)
# Number of dimensions for each protein tensor field without the batch dimension.
_DIMS: dict[str, int] = {
"sequence": 1,
"structure": 1,
"secondary_structure": 1,
"sasa": 1,
"function": 2,
"residue_annotations": 2,
"coordinates": 3,
}
def _non_batched_dims(k: str, v: torch.Tensor):
match k:
case "sequence":
return 1
case "structure":
if v.is_floating_point():
# This is the one hot soft structure token.
return 2
else:
# This is the normal int structure token.
return 1
case "secondary_structure":
return 1
case "sasa":
return 1
case "function":
return 2
case "residue_annotations":
return 2
case "coordinates":
return 3
case _:
raise ValueError(f"Unknown dim for track {k}")
class _BatchedESMProteinTensor(ESMProteinTensor):
@@ -52,7 +66,7 @@ class _BatchedESMProteinTensor(ESMProteinTensor):
def __len__(self) -> int:
def get_len(k, v) -> int:
assert len(v.shape) == _DIMS[k] + 1
assert len(v.shape) == _non_batched_dims(k, v) + 1
return v.size(1)
l = self._detect_attribute(get_len, "length")
@@ -61,18 +75,14 @@ class _BatchedESMProteinTensor(ESMProteinTensor):
@property
def batch_size(self) -> int:
def get_batch_size(k, v) -> int:
assert len(v.shape) == _DIMS[k] + 1
assert len(v.shape) == _non_batched_dims(k, v) + 1
return v.size(0)
d = self._detect_attribute(get_batch_size, "batch size")
assert d is not None
return d
def slice(
self,
i: int,
sequence_len: int | None = None,
) -> ESMProteinTensor:
def slice(self, i: int, sequence_len: int | None = None) -> ESMProteinTensor:
def _maybe_slice(x: torch.Tensor | None):
if x is None:
return None
@@ -130,8 +140,7 @@ def get_default_sampling_config(
def validate_sampling_config(
sampling_config: SamplingConfig,
on_invalid: Literal["raise", "warn"] = "warn",
sampling_config: SamplingConfig, on_invalid: Literal["raise", "warn"] = "warn"
):
# Check that all tracks have topk_logprobs less or equal to MAX_TOP_K
for track in attr.fields(SamplingConfig):
@@ -288,10 +297,7 @@ def sample_sasa_logits(
return sasa_value
def top_p_logits(
logits: torch.Tensor,
top_p: float | torch.Tensor,
) -> torch.Tensor:
def top_p_logits(logits: torch.Tensor, top_p: float | torch.Tensor) -> torch.Tensor:
top_p = _tensorize_like(top_p, logits)
batch_dims = logits.size()[:-1]
@@ -320,9 +326,7 @@ def _tensorize_like(value: int | float | torch.Tensor, logits: torch.Tensor):
def get_sampling_mask(
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
tokens: torch.Tensor, sampling_track_config: SamplingTrackConfig, mask_idx: int
):
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (B, L, )

View File

@@ -31,9 +31,7 @@ def test_sample_logits():
with pytest.raises(ValueError):
sampled = sample_logits(
logits=torch.randn((8, 4096)),
temperature=0.0,
valid_ids=[],
logits=torch.randn((8, 4096)), temperature=0.0, valid_ids=[]
)

View File

@@ -12,15 +12,12 @@ from esm.utils.misc import fp32_autocast_context
@T.runtime_checkable
class Rotation(T.Protocol):
@classmethod
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self:
...
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
@classmethod
def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self:
...
def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
def __getitem__(self, idx: T.Any) -> Self:
...
def __getitem__(self, idx: T.Any) -> Self: ...
@property
def tensor(self) -> torch.Tensor:
@@ -35,8 +32,7 @@ class Rotation(T.Protocol):
# This means that 1x4 quaternions are treated as size (1,) for example
...
def as_matrix(self) -> RotationMatrix:
...
def as_matrix(self) -> RotationMatrix: ...
def compose(self, other: Self) -> Self:
# To be safe, we force users to explicitly convert between rotation types.
@@ -50,8 +46,7 @@ class Rotation(T.Protocol):
# rotates points by this rotation object
...
def invert(self) -> Self:
...
def invert(self) -> Self: ...
@property
def dtype(self) -> torch.dtype:
@@ -194,10 +189,7 @@ class Affine3D:
def __getitem__(self, idx: T.Any) -> "Affine3D":
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
return Affine3D(
trans=self.trans[indices + (slice(None),)],
rot=self.rot[idx],
)
return Affine3D(trans=self.trans[indices + (slice(None),)], rot=self.rot[idx])
@property
def shape(self) -> torch.Size:

View File

@@ -17,8 +17,7 @@ class Alignable(Protocol):
# Trick to detect whether an object is a dataclass
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
def __len__(self) -> int:
...
def __len__(self) -> int: ...
T = TypeVar("T", bound=Alignable)

View File

@@ -139,18 +139,13 @@ def compute_gdt_ts(
"""
if atom_exists_mask is None:
atom_exists_mask = torch.isfinite(target).all(dim=-1)
(
centered_mobile,
_,
centered_target,
_,
rotation_matrix,
_,
) = compute_alignment_tensors(
mobile=mobile,
target=target,
atom_exists_mask=atom_exists_mask,
sequence_id=sequence_id,
(centered_mobile, _, centered_target, _, rotation_matrix, _) = (
compute_alignment_tensors(
mobile=mobile,
target=target,
atom_exists_mask=atom_exists_mask,
sequence_id=sequence_id,
)
)
# Apply transformation to centered structure

View File

@@ -43,10 +43,7 @@ def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
Affine3D: tensor of Affine3D frame
"""
bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)
coord_mask = torch.all(
torch.all(torch.isfinite(bb_coords), dim=-1),
dim=-1,
)
coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1)
average_position_per_n_ca_c = bb_coords.masked_fill(
~coord_mask[..., None, None], 0

View File

@@ -49,11 +49,7 @@ def compute_predicted_aligned_error(
@torch.no_grad
def compute_tm(
logits: torch.Tensor,
aa_mask: torch.Tensor,
max_bin: float = 31.0,
):
def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0):
square_mask = _compute_pae_masks(aa_mask)
seqlens = aa_mask.sum(-1, keepdim=True)
bins = _pae_bins(max_bin, logits.shape[-1], logits.device)

View File

@@ -229,8 +229,7 @@ class ProteinChain:
return buf.getvalue()
def to_structure_encoder_inputs(
self,
should_normalize_coordinates: bool = True,
self, should_normalize_coordinates: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
coords = torch.tensor(self.atom37_positions, dtype=torch.float32)
plddt = torch.tensor(self.confidence, dtype=torch.float32)
@@ -494,9 +493,7 @@ class ProteinChain:
@classmethod
def from_backbone_atom_coordinates(
cls,
backbone_atom_coordinates: np.ndarray | torch.Tensor,
**kwargs,
cls, backbone_atom_coordinates: np.ndarray | torch.Tensor, **kwargs
):
"""Create a ProteinChain from a set of backbone atom coordinates.
@@ -529,10 +526,7 @@ class ProteinChain:
)
atom37_positions[:, :3, :] = backbone_atom_coordinates
return cls.from_atom37(
atom37_positions=atom37_positions,
**kwargs,
)
return cls.from_atom37(atom37_positions=atom37_positions, **kwargs)
@classmethod
def from_pdb(
@@ -586,22 +580,13 @@ class ProteinChain:
num_res = len(sequence)
atom_positions = np.full(
[num_res, RC.atom_type_num, 3],
np.nan,
dtype=np.float32,
)
atom_mask = np.full(
[num_res, RC.atom_type_num],
False,
dtype=bool,
[num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32
)
atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool)
residue_index = np.full([num_res], -1, dtype=np.int64)
insertion_code = np.full([num_res], "", dtype="<U4")
confidence = np.ones(
[num_res],
dtype=np.float32,
)
confidence = np.ones([num_res], dtype=np.float32)
for i, res in enumerate(bs.residue_iter(atom_array)):
chain = atom_array[atom_array.chain_id == chain_id]
@@ -639,20 +624,14 @@ class ProteinChain:
)
@classmethod
def from_rcsb(
cls,
pdb_id: str,
chain_id: str = "detect",
):
def from_rcsb(cls, pdb_id: str, chain_id: str = "detect"):
"""Fetch a protein chain from the RCSB PDB database."""
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
return cls.from_pdb(f, chain_id=chain_id, id=pdb_id)
@classmethod
def from_atomarray(
cls,
atom_array: bs.AtomArray,
id: str | None = None,
cls, atom_array: bs.AtomArray, id: str | None = None
) -> "ProteinChain":
"""A simple converter from bs.AtomArray -> ProteinChain.
Uses PDB file format as intermediate."""

View File

@@ -254,10 +254,7 @@ class ProteinComplex:
return ProteinComplex.from_chains(chains)
@classmethod
def from_rcsb(
cls,
pdb_id: str,
):
def from_rcsb(cls, pdb_id: str):
"""Fetch a protein complex from the RCSB PDB database."""
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
return cls.from_pdb(f, id=pdb_id)
@@ -345,10 +342,7 @@ class ProteinComplex:
)
@classmethod
def from_chains(
cls,
chains: Sequence[ProteinChain],
):
def from_chains(cls, chains: Sequence[ProteinChain]):
if not chains:
raise ValueError(
"Cannot create a ProteinComplex from an empty list of chains"

View File

@@ -254,10 +254,7 @@ def compute_affine_and_rmsd(
# Apply transformation to centered structure to compute rmsd
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
avg_rmsd = compute_rmsd_no_alignment(
rotated_mobile,
centered_target,
num_valid_atoms,
reduction="batch",
rotated_mobile, centered_target, num_valid_atoms, reduction="batch"
)
return affine, avg_rmsd

View File

@@ -116,19 +116,10 @@ def create_function_annotator(
)
delete_button = widgets.Button(
description="Delete",
tooltip="Delete this annotation",
icon="trash",
)
entry = widgets.HBox(
[
delete_button,
widgets.Label(value=function_str),
]
)
delete_button.on_click(
on_delete_click,
description="Delete", tooltip="Delete this annotation", icon="trash"
)
entry = widgets.HBox([delete_button, widgets.Label(value=function_str)])
delete_button.on_click(on_delete_click)
entries.children += (entry,)
except Exception as e:

View File

@@ -150,8 +150,7 @@ def create_sequence_results_page(
sequence_items = []
for item in items:
copy_to_prompt_button = widgets.Button(
description="Copy to Prompt",
disabled=copy_to_prompt_callback is None,
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
)
if copy_to_prompt_callback:
copy_to_prompt_button.on_click(
@@ -190,8 +189,7 @@ def create_sasa_results_page(
sasa_items = []
for item in items:
copy_to_prompt_button = widgets.Button(
description="Copy to Prompt",
disabled=copy_to_prompt_callback is None,
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
)
if copy_to_prompt_callback:
copy_to_prompt_button.on_click(lambda b: copy_to_prompt_callback(item.sasa))
@@ -201,11 +199,7 @@ def create_sasa_results_page(
print("Solvent Accessible Surface Area (SASA) is not available.")
else:
sasa = [s or 0 for s in item.sasa]
draw_data_array(
output,
data_array=sasa,
cmap="Reds",
)
draw_data_array(output, data_array=sasa, cmap="Reds")
if copy_to_prompt_callback:
sasa_items.append(
@@ -227,8 +221,7 @@ def create_secondary_structure_results_page(
ss_items = []
for item in items:
copy_to_prompt_button = widgets.Button(
description="Copy to Prompt",
disabled=copy_to_prompt_callback is None,
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
)
if copy_to_prompt_callback:
copy_to_prompt_button.on_click(
@@ -292,8 +285,7 @@ def create_structure_results_page(
else:
ptm_label = widgets.Label(value=f"pTM: {item.ptm.item():.2f}")
copy_to_prompt_button = widgets.Button(
description="Copy to Prompt",
disabled=copy_to_prompt_callback is None,
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
)
if copy_to_prompt_callback:
copy_to_prompt_button.on_click(
@@ -351,8 +343,7 @@ def create_structure_results_page(
header = widgets.HBox([download_pdb_button, ptm_label])
grid[row, col] = widgets.VBox(
[header, output],
layout={"border": "1px solid gray"},
[header, output], layout={"border": "1px solid gray"}
)
return grid
@@ -364,8 +355,7 @@ def create_function_annotations_results_page(
function_items = []
for item in items:
copy_to_prompt_button = widgets.Button(
description="Copy to Prompt",
disabled=copy_to_prompt_callback is None,
description="Copy to Prompt", disabled=copy_to_prompt_callback is None
)
if copy_to_prompt_callback:
copy_to_prompt_button.on_click(
@@ -387,8 +377,7 @@ def create_function_annotations_results_page(
)
else:
image = draw_function_annotations(
interpro_annotations,
sequence_length=len(item),
interpro_annotations, sequence_length=len(item)
)
if copy_to_prompt_callback:
content = widgets.VBox(

View File

@@ -155,11 +155,7 @@ def get_secondary_structure(protein_chain: ProteinChain) -> Sequence[int]:
def get_ss3_categories():
return [
"Coil (C)",
"Alpha helix (H)",
"Beta strand (E)",
]
return ["Coil (C)", "Alpha helix (H)", "Beta strand (E)"]
def ss3_plot_index_to_letter(ss3_index: int) -> str:

View File

@@ -124,9 +124,9 @@ def create_sequence_prompt_selector(
r, g, b, a = hex_to_rgba_tuple(combined_color)
a = 0.5 # Set alpha to 0.5
combined_color = rgba_tuple_to_rgba_html_string((r, g, b, a))
highlighted_line[
i
] = f'<span style="background-color:{combined_color}">{highlighted_line[i]}</span>'
highlighted_line[i] = (
f'<span style="background-color:{combined_color}">{highlighted_line[i]}</span>'
)
highlighted_lines.append("".join(highlighted_line))
return "<br>".join(highlighted_lines)

View File

@@ -112,10 +112,9 @@ def create_structure_prompt_selector(
).items():
selected_ranges = tuple(selected_ranges) # Convert to hashable
if selected_ranges in contact_map_selection_cache:
(
(x_start, x_end),
(y_start, y_end),
) = contact_map_selection_cache[selected_ranges]
((x_start, x_end), (y_start, y_end)) = contact_map_selection_cache[
selected_ranges
]
rect = Rectangle(
(x_start - 0.5, max_y - y_end - 1.5),
x_end - x_start + 1,

View File

@@ -24,7 +24,5 @@ def get_forge_client(model_name: str) -> ESM3InferenceClient:
"Forge API key not found. Please set the ESM_API_KEY environment variable."
)
return ESM3ForgeInferenceClient(
model=model_name,
url="https://forge.evolutionaryscale.ai",
token=forge_token,
model=model_name, url="https://forge.evolutionaryscale.ai", token=forge_token
)

View File

@@ -110,8 +110,7 @@ def draw_data_array(
else:
legend_patches = [
patches.Patch(
color=rgb_colors[category_to_index[category]],
label=category,
color=rgb_colors[category_to_index[category]], label=category
)
for category in categories
]

View File

@@ -26,9 +26,7 @@ def use_backend(backend):
def draw_function_annotations(
annotations: list[FunctionAnnotation],
sequence_length: int,
interpro_=InterPro(),
annotations: list[FunctionAnnotation], sequence_length: int, interpro_=InterPro()
) -> widgets.Image:
cmap = colormaps["tab10"]
colors = [cmap(i) for i in range(len(InterProEntryType))]
@@ -63,9 +61,7 @@ def draw_function_annotations(
with use_backend("agg"):
fig, ax = plt.subplots()
record = GraphicRecord(
sequence=None,
sequence_length=sequence_length,
features=features,
sequence=None, sequence_length=sequence_length, features=features
)
record.plot(ax=ax, plot_sequence=False)
fig.savefig(buf, format="png", dpi=200, bbox_inches="tight")

View File

@@ -19,8 +19,7 @@ def draw_protein_structure(
for start, end, color in highlighted_ranges:
view.setStyle(
{"resi": str(start) + "-" + str(end)},
{"cartoon": {"color": color}},
{"resi": str(start) + "-" + str(end)}, {"cartoon": {"color": color}}
)
view.zoomTo()

View File

@@ -8,18 +8,13 @@ PDB_INDEX = "PDB index"
PDB_INDEX_SUFFIX = "[PDB Index]"
def get_pdb_index_min_max(
protein_chain: ProteinChain,
) -> tuple[int, int]:
def get_pdb_index_min_max(protein_chain: ProteinChain) -> tuple[int, int]:
residue_index = protein_chain.residue_index
valid_residue_index = residue_index[residue_index != -1]
return min(valid_residue_index), max(valid_residue_index)
def pdb_index_to_zero_index(
residue_index: int,
protein_chain: ProteinChain,
) -> int:
def pdb_index_to_zero_index(residue_index: int, protein_chain: ProteinChain) -> int:
# Find the first position equal to residue_index
pos = np.argwhere(residue_index == protein_chain.residue_index)
if len(pos) == 0:
@@ -27,16 +22,12 @@ def pdb_index_to_zero_index(
return pos[0][0]
def zero_index_to_pdb_index(
zero_index: int,
protein_chain: ProteinChain,
) -> int:
def zero_index_to_pdb_index(zero_index: int, protein_chain: ProteinChain) -> int:
return protein_chain.residue_index[zero_index]
def zero_range_to_pdb_range(
zero_range: tuple[int, int],
protein_chain: ProteinChain,
zero_range: tuple[int, int], protein_chain: ProteinChain
) -> tuple[int, int]:
return (
zero_index_to_pdb_index(zero_range[0], protein_chain),
@@ -45,8 +36,7 @@ def zero_range_to_pdb_range(
def pdb_range_to_zero_range(
pdb_range: tuple[int, int],
protein_chain: ProteinChain,
pdb_range: tuple[int, int], protein_chain: ProteinChain
) -> tuple[int, int]:
return (
pdb_index_to_zero_index(pdb_range[0], protein_chain),

View File

@@ -196,9 +196,7 @@ class PromptManager:
def redraw(self, change=None):
categories = ["Mask (-)"]
color_map = {
"Mask (-)": "white",
}
color_map = {"Mask (-)": "white"}
data_array = [0] * self.prompt_length
for prompt_str, *_ in self.prompts.items():
color, _, _ = self.prompts[prompt_str]
@@ -282,7 +280,7 @@ class PromptManager:
value=(
f'<div style="display: inline-block; width: 10px; height: 10px; background-color:{label_color}; margin-right: 5px;"></div>'
f"{range_string}"
),
)
)
entry_label.tag = range_string # type: ignore
entry_container = widgets.HBox([entry_button, entry_label])

View File

@@ -9,11 +9,7 @@ from esm.widgets.utils.printing import wrapped_print
class ProteinImporter:
def __init__(
self,
max_proteins: int | None = None,
autoload: bool = False,
) -> None:
def __init__(self, max_proteins: int | None = None, autoload: bool = False) -> None:
self._protein_list: list[tuple[str, ProteinChain]] = []
self._protein_workspace: dict[str, str] = {}
self.max_proteins = max_proteins

View File

@@ -55,9 +55,7 @@ def create_download_results_button(
)
def serialize_protein(
protein: ESMProtein,
) -> str:
def serialize_protein(protein: ESMProtein) -> str:
protein_dict = {
"sequence": protein.sequence,
"coordinates": protein.coordinates.tolist()

View File

@@ -125,9 +125,7 @@ def create_esm3_generation_launcher(
]
)
generation_config_ui = widgets.VBox(
[generation_config_settings_ui],
)
generation_config_ui = widgets.VBox([generation_config_settings_ui])
def on_track_change(change):
if change["new"] == "function":

View File

@@ -12,9 +12,7 @@ from esm.widgets.components.sequence_prompt_selector import (
from esm.widgets.components.structure_prompt_selector import (
create_structure_prompt_selector,
)
from esm.widgets.utils.prompting import (
PromptManagerCollection,
)
from esm.widgets.utils.prompting import PromptManagerCollection
from esm.widgets.utils.protein_import import ProteinImporter

View File

@@ -2,20 +2,13 @@ from typing import Any, Literal
from ipywidgets import widgets
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
)
from esm.sdk.api import ESM3InferenceClient, ESMProtein
from esm.utils.constants import esm3 as C
from esm.widgets.components.function_annotator import (
create_function_annotator,
)
from esm.widgets.utils.prompting import (
PromptManagerCollection,
)
from esm.widgets.utils.protein_import import (
ProteinImporter,
)
from esm.widgets.utils.prompting import PromptManagerCollection
from esm.widgets.utils.protein_import import ProteinImporter
from esm.widgets.views.esm3_generation_launcher import (
create_esm3_generation_launcher,
)
@@ -47,14 +40,9 @@ def create_generation_ui(
protein_length_ui = widgets.VBox(
[
widgets.HTML(value="<h3>Specify Prompt Length:</h3>"),
widgets.HBox(
[
protein_length_input,
protein_length_confirm_button,
]
),
widgets.HBox([protein_length_input, protein_length_confirm_button]),
output,
],
]
)
loading_ui = widgets.HTML(value="<h3>Loading...</h3>")
@@ -117,10 +105,7 @@ def create_generation_ui(
add_annotation_callback=prompt_manager_collection.add_function_annotation,
delete_annotation_callback=prompt_manager_collection.delete_function_annotation,
)
function_annotator_ui.children = [
function_annotator_title,
function_annotator,
]
function_annotator_ui.children = [function_annotator_title, function_annotator]
if len(protein_importer.protein_list) == 0:
prompt_ui.children = [
@@ -139,10 +124,7 @@ def create_generation_ui(
esm3_selector_ui = create_esm3_prompt_selector(
prompt_manager_collection, protein_importer=protein_importer
)
selector_ui.children = [
selector_title,
esm3_selector_ui,
]
selector_ui.children = [selector_title, esm3_selector_ui]
prompt_ui.children = [
protein_importer_ui,
protein_length_ui,
@@ -184,10 +166,7 @@ def create_generation_ui(
copy_to_prompt_callback=copy_to_prompt_callback,
)
generation_launcher_ui = widgets.VBox(
[
widgets.HTML(value="<h3>Generation Config:</h3>"),
generation_launcher,
]
[widgets.HTML(value="<h3>Generation Config:</h3>"), generation_launcher]
)
if len(protein_importer.protein_list) > 0:

View File

@@ -10,9 +10,7 @@ from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import (
ProteinImporter,
)
from esm.widgets.utils.protein_import import ProteinImporter
def create_inverse_folding_ui(client: ESM3InferenceClient) -> widgets.Widget:

View File

@@ -133,10 +133,7 @@ def create_login_ui(client_container: ClientInitContainer):
start_msg_output,
]
elif change["new"] == "Local":
model_selection_ui.children = [
model_selection_header,
local_model,
]
model_selection_ui.children = [model_selection_header, local_model]
login_ui.children = [
infobox,
selection_ui,

View File

@@ -10,9 +10,7 @@ from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import (
ProteinImporter,
)
from esm.widgets.utils.protein_import import ProteinImporter
def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget:
@@ -85,11 +83,7 @@ def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget:
try:
# Reset the output and results
output.clear_output()
prediction_ui.children = [
input_ui,
predict_button,
output,
]
prediction_ui.children = [input_ui, predict_button, output]
# Predict the protein's properties
with output:
protein = get_protein()
@@ -159,19 +153,10 @@ def create_prediction_ui(client: ESM3InferenceClient) -> widgets.Widget:
wrapped_print(e)
predict_button.on_click(on_click_predict)
protein_importer.entries_box.observe(
on_new_protein,
names="children",
)
protein_importer.entries_box.observe(on_new_protein, names="children")
protein_importer.register_delete_callback(lambda: validate_predict(None))
sequence_input_ui.children[1].observe(
on_new_sequence,
names="value",
)
input_ui.observe(
validate_predict,
names="selected_index",
)
sequence_input_ui.children[1].observe(on_new_sequence, names="value")
input_ui.observe(validate_predict, names="selected_index")
return prediction_ui

View File

@@ -7,8 +7,9 @@ from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
GenerationConfig,
InverseFoldingConfig,
)
from esm.sdk.forge import FoldForgeInferenceClient
from esm.sdk.forge import SequenceStructureForgeInferenceClient
def convert_none_to_nan(data):
@@ -21,46 +22,53 @@ def convert_none_to_nan(data):
return data
def are_allclose_with_nan(A, B, rtol=1e-5, atol=1e-2):
B = convert_none_to_nan(B)
A = np.array(A)
B = np.array(B)
if A.shape != B.shape:
raise ValueError("A and B must have the same shape")
nan_mask_A = np.isnan(A)
nan_mask_B = np.isnan(B)
if not np.array_equal(nan_mask_A, nan_mask_B):
return False
return np.allclose(A[~nan_mask_A], B[~nan_mask_B], rtol=rtol, atol=atol)
def main(fold_client: FoldForgeInferenceClient, esm3_client: ESM3InferenceClient):
# Folding
def main(
sequence_structure_client: SequenceStructureForgeInferenceClient,
esm3_client: ESM3InferenceClient,
):
# Folding with esm3 client
protein = get_sample_protein()
sequence_length = len(protein.sequence) # type: ignore
num_steps = int(sequence_length / 16)
protein.coordinates = None
protein.function_annotations = None
protein.sasa = None
assert protein.sequence is not None, "Protein sequence must be set to fold"
# Folding with esm3 client
folded_protein = cast(
ESMProtein,
esm3_client.generate(
protein,
GenerationConfig(
track="structure", schedule="cosine", num_steps=num_steps, temperature=0
),
),
)
config = GenerationConfig(track="structure", num_steps=1, temperature=0)
esm3_client_folded_protein = esm3_client.generate(protein, config)
assert isinstance(
esm3_client_folded_protein, ESMProtein
), f"Using ESM3 client, ESMProtein was expected but got {protein}"
# Folding with folding client
coordinates = fold_client.fold(
"esm3",
protein.sequence, # type:ignore
potential_sequence_of_concern=False,
sequence_structure_client_folded_protein = sequence_structure_client.fold(
protein.sequence, potential_sequence_of_concern=False
)
assert are_allclose_with_nan(folded_protein.coordinates, coordinates)
assert isinstance(
sequence_structure_client_folded_protein, ESMProtein
), f"Using sequence_structure client, ESMProtein was expected but got {sequence_structure_client_folded_protein}"
# Inverse Folding with esm3 client
protein = get_sample_protein()
protein.sequence = None
protein.sasa = None
protein.function_annotations = None
assert (
protein.coordinates is not None
), "Protein coordinates must be set to inverse fold"
config = GenerationConfig("sequence", num_steps=1, temperature=0.7)
esm3_client_inv_folded_protein = cast(
ESMProtein, esm3_client.generate(protein, config)
)
assert isinstance(
esm3_client_inv_folded_protein, ESMProtein
), f"Using ESM3 client, ESMProtein was expected but got {protein}"
# Inverse Folding with inverse folding client
sequence_structure_client_inv_folded_protein = (
sequence_structure_client.inverse_fold(
protein.coordinates,
config=InverseFoldingConfig(temperature=0.7),
potential_sequence_of_concern=False,
)
)
assert isinstance(
sequence_structure_client_inv_folded_protein, ESMProtein
), f"Using sequence_structure client, ESMProtein was expected but got {sequence_structure_client_inv_folded_protein}"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,5 @@
import torch
from esm.models.esm3 import ESM3
from esm.sdk.api import (
ESM3InferenceClient,
@@ -38,8 +40,7 @@ def main(client: ESM3InferenceClient):
protein.function_annotations = None
protein = client.encode(protein)
single_step_protein = client.forward_and_sample(
protein,
SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2)),
protein, SamplingConfig(structure=SamplingTrackConfig(topk_logprobs=2))
)
single_step_protein.protein_tensor.sequence = protein.sequence
single_step_protein = client.decode(single_step_protein.protein_tensor)
@@ -52,8 +53,7 @@ def main(client: ESM3InferenceClient):
)
protein = ESMProtein(sequence=prompt)
protein = client.generate(
protein,
GenerationConfig(track="sequence", num_steps=8, temperature=0.7),
protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7)
)
assert isinstance(protein, ESMProtein), f"ESMProtein was expected but got {protein}"
@@ -189,5 +189,6 @@ def main(client: ESM3InferenceClient):
assert isinstance(p, ESMProtein), f"ESMProtein was expected but got {p}"
if __name__ == "__main__":
main(ESM3.from_pretrained("esm3_sm_open_v1"))

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.0.7post1"
version = "3.0.8"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"

View File

@@ -53,9 +53,7 @@
"outputs": [],
"source": [
"from esm.widgets.utils.types import ClientInitContainer\n",
"from esm.widgets.views.inverse_folding import (\n",
" create_inverse_folding_ui,\n",
")\n",
"from esm.widgets.views.inverse_folding import create_inverse_folding_ui\n",
"from esm.widgets.views.login import create_login_ui"
]
},