mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Sync from internal
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "0.2rc1"
|
||||
__version__ = "3.0.1"
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import attr
|
||||
import einops
|
||||
@@ -28,7 +29,9 @@ from esm.sdk.api import (
|
||||
ProteinType,
|
||||
SamplingConfig,
|
||||
)
|
||||
from esm.tokenization import get_model_tokenizers
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
)
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
@@ -202,9 +205,10 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
n_heads: int,
|
||||
v_heads: int,
|
||||
n_layers: int,
|
||||
structure_encoder_name: str,
|
||||
structure_decoder_name: str,
|
||||
function_decoder_name: str,
|
||||
structure_encoder_fn: Callable[[torch.device | str], StructureTokenEncoder],
|
||||
structure_decoder_fn: Callable[[torch.device | str], StructureTokenDecoder],
|
||||
function_decoder_fn: Callable[[torch.device | str], FunctionTokenDecoder],
|
||||
tokenizers: TokenizerCollectionProtocol,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = EncodeInputs(d_model)
|
||||
@@ -217,15 +221,15 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
)
|
||||
self.output_heads = OutputHeads(d_model)
|
||||
|
||||
self.structure_encoder_name = structure_encoder_name
|
||||
self.structure_decoder_name = structure_decoder_name
|
||||
self.function_decoder_name = function_decoder_name
|
||||
self.structure_encoder_fn = structure_encoder_fn
|
||||
self.structure_decoder_fn = structure_decoder_fn
|
||||
self.function_decoder_fn = function_decoder_fn
|
||||
|
||||
self.structure_encoder: StructureTokenEncoder | None = None
|
||||
self.structure_decoder: StructureTokenDecoder | None = None
|
||||
self.function_decoder: FunctionTokenDecoder | None = None
|
||||
self._structure_encoder = None
|
||||
self._structure_decoder = None
|
||||
self._function_decoder = None
|
||||
|
||||
self.tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
|
||||
self.tokenizers = tokenizers
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -245,32 +249,24 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
assert isinstance(model, ESM3)
|
||||
return model
|
||||
|
||||
def get_structure_token_encoder(self) -> StructureTokenEncoder:
|
||||
if self.structure_encoder is None:
|
||||
model = self.load_model(self.structure_encoder_name)
|
||||
assert isinstance(model, StructureTokenEncoder)
|
||||
self.structure_encoder = model
|
||||
return self.structure_encoder
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def get_structure_token_decoder(self) -> StructureTokenDecoder:
|
||||
if self.structure_decoder is None:
|
||||
model = self.load_model(self.structure_decoder_name)
|
||||
assert isinstance(model, StructureTokenDecoder)
|
||||
self.structure_decoder = model
|
||||
return self.structure_decoder
|
||||
def get_structure_encoder(self) -> StructureTokenEncoder:
|
||||
if self._structure_encoder is None:
|
||||
self._structure_encoder = self.structure_encoder_fn(self.device)
|
||||
return self._structure_encoder
|
||||
|
||||
def get_function_token_decoder(self) -> FunctionTokenDecoder:
|
||||
if self.function_decoder is None:
|
||||
model = self.load_model(self.function_decoder_name)
|
||||
assert isinstance(model, FunctionTokenDecoder)
|
||||
self.function_decoder = model
|
||||
return self.function_decoder
|
||||
def get_structure_decoder(self) -> StructureTokenDecoder:
|
||||
if self._structure_decoder is None:
|
||||
self._structure_decoder = self.structure_decoder_fn(self.device)
|
||||
return self._structure_decoder
|
||||
|
||||
def load_model(self, model_name: str):
|
||||
# Lazy import from pretrained
|
||||
from esm.pretrained import load_local_model
|
||||
|
||||
return load_local_model(model_name, device=next(self.parameters()).device)
|
||||
def get_function_decoder(self) -> FunctionTokenDecoder:
|
||||
if self._function_decoder is None:
|
||||
self._function_decoder = self.function_decoder_fn(self.device)
|
||||
return self._function_decoder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -360,15 +356,10 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
] # In case we pass in an atom14 or atom37 repr
|
||||
affine, affine_mask = build_affine3d_from_coordinates(structure_coords)
|
||||
|
||||
if structure_tokens is None:
|
||||
_, structure_tokens = self.get_structure_token_encoder().encode(
|
||||
structure_coords
|
||||
)
|
||||
structure_tokens = defaults(structure_tokens, C.STRUCTURE_MASK_TOKEN)
|
||||
assert structure_tokens is not None
|
||||
structure_tokens = (
|
||||
structure_tokens.masked_fill(
|
||||
(structure_tokens == -1) | ~affine_mask, C.STRUCTURE_MASK_TOKEN
|
||||
)
|
||||
structure_tokens.masked_fill(structure_tokens == -1, C.STRUCTURE_MASK_TOKEN)
|
||||
.masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN)
|
||||
.masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN)
|
||||
.masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN)
|
||||
@@ -405,7 +396,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
configs
|
||||
), "Must have the same number of prompts and configs."
|
||||
|
||||
if inputs is []:
|
||||
if inputs == []:
|
||||
# Nothing to do.
|
||||
return []
|
||||
|
||||
@@ -469,7 +460,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
if input.coordinates is not None:
|
||||
coordinates, _, structure_tokens = encoding.tokenize_structure(
|
||||
input.coordinates,
|
||||
self.get_structure_token_encoder(),
|
||||
self.get_structure_encoder(),
|
||||
structure_tokenizer=self.tokenizers.structure,
|
||||
reference_sequence=input.sequence or "",
|
||||
add_special_tokens=True,
|
||||
@@ -517,8 +508,8 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
return decode_protein_tensor(
|
||||
input=input,
|
||||
tokenizers=self.tokenizers,
|
||||
structure_token_decoder=self.get_structure_token_decoder(),
|
||||
function_token_decoder=self.get_function_token_decoder(),
|
||||
structure_token_decoder=self.get_structure_decoder(),
|
||||
function_token_decoder=self.get_function_decoder(),
|
||||
)
|
||||
|
||||
def _forward(
|
||||
|
||||
@@ -9,6 +9,7 @@ from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.tokenization import get_model_tokenizers
|
||||
from esm.utils.constants.esm3 import data_root
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_FUNCTION_DECODER_V0,
|
||||
@@ -20,24 +21,6 @@ from esm.utils.constants.models import (
|
||||
ModelBuilder = Callable[[torch.device | str], nn.Module]
|
||||
|
||||
|
||||
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
||||
with torch.device(device):
|
||||
model = ESM3(
|
||||
d_model=1536,
|
||||
n_heads=24,
|
||||
v_heads=256,
|
||||
n_layers=48,
|
||||
structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0,
|
||||
structure_decoder_name=ESM3_STRUCTURE_DECODER_V0,
|
||||
function_decoder_name=ESM3_FUNCTION_DECODER_V0,
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
|
||||
)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
|
||||
with torch.device(device):
|
||||
model = StructureTokenEncoder(
|
||||
@@ -70,6 +53,25 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
|
||||
return model
|
||||
|
||||
|
||||
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
||||
with torch.device(device):
|
||||
model = ESM3(
|
||||
d_model=1536,
|
||||
n_heads=24,
|
||||
v_heads=256,
|
||||
n_layers=48,
|
||||
structure_encoder_fn=ESM3_structure_encoder_v0,
|
||||
structure_decoder_fn=ESM3_structure_decoder_v0,
|
||||
function_decoder_fn=ESM3_function_decoder_v0,
|
||||
tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
|
||||
)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
|
||||
ESM3_OPEN_SMALL: ESM3_sm_open_v0,
|
||||
ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0,
|
||||
|
||||
@@ -193,7 +193,9 @@ class GenerationConfig:
|
||||
track: str = ""
|
||||
invalid_ids: Sequence[int] = []
|
||||
schedule: str = "cosine"
|
||||
num_steps: int = 8
|
||||
# Set this to a higher value for better generation results.
|
||||
# Note that this needs to be less than or equal to the sequence length.
|
||||
num_steps: int = 1
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
condition_on_coordinates_only: bool = True
|
||||
|
||||
@@ -61,6 +61,11 @@ def decode_protein_tensor(
|
||||
track_tokenizer = getattr(tokenizers, track.name)
|
||||
if torch.all(tokens == track_tokenizer.pad_token_id):
|
||||
setattr(input, track.name, None)
|
||||
# If structure track has any mask tokens, do not decode.
|
||||
if track.name == "structure" and torch.any(
|
||||
tokens == track_tokenizer.mask_token_id
|
||||
):
|
||||
setattr(input, track.name, None)
|
||||
|
||||
if input.sequence is not None:
|
||||
sequence = decode_sequence(input.sequence, tokenizers.sequence)
|
||||
|
||||
@@ -104,7 +104,7 @@ def tokenize_structure(
|
||||
structure_tokens = F.pad(
|
||||
structure_tokens,
|
||||
(left_pad, right_pad),
|
||||
value=structure_tokenizer.pad_token_id,
|
||||
value=structure_tokenizer.mask_token_id,
|
||||
)
|
||||
structure_tokens[0] = structure_tokenizer.bos_token_id
|
||||
structure_tokens[-1] = structure_tokenizer.eos_token_id
|
||||
@@ -186,7 +186,7 @@ def get_default_structure_tokens(
|
||||
(sequence_length + 2,),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
* structure_tokenizer.pad_token_id
|
||||
* structure_tokenizer.mask_token_id
|
||||
)
|
||||
# Always include BOS and EOS tokens
|
||||
structure_tokens[0] = structure_tokenizer.bos_token_id
|
||||
|
||||
@@ -136,6 +136,8 @@ def _make_masked_inputs(
|
||||
|
||||
if track == "coordinates":
|
||||
dims = (sequence_length, 3, 3)
|
||||
elif track == "confidence":
|
||||
dims = (sequence_length,)
|
||||
elif track == "attention_mask":
|
||||
dims = (sequence_length,)
|
||||
elif track == "function":
|
||||
@@ -147,6 +149,9 @@ def _make_masked_inputs(
|
||||
|
||||
if track == "coordinates":
|
||||
masked_tokens = torch.full(dims, torch.inf, dtype=torch.float)
|
||||
elif track == "confidence":
|
||||
# All-mask dummy input for confidence track.
|
||||
masked_tokens = torch.full(dims, 0.0)
|
||||
elif track == "attention_mask":
|
||||
masked_tokens = torch.full(dims, 1, dtype=torch.bool)
|
||||
else:
|
||||
@@ -302,7 +307,7 @@ def iterative_sampling_tokens(
|
||||
sequence_lengths = [len(tokens) for tokens in sampled_tokens]
|
||||
# Figure out the number of tokens to be sampled for each prompt.
|
||||
total_to_sample = []
|
||||
for protein, seq_len, config in zip(input_tokens, sequence_lengths, configs):
|
||||
for protein, seq_len, config in zip(sampled_tokens, sequence_lengths, configs):
|
||||
track = config.track
|
||||
|
||||
if getattr(protein, track) is None:
|
||||
@@ -324,7 +329,7 @@ def iterative_sampling_tokens(
|
||||
|
||||
# Now stack the list to make a single batched ESMProteinTensor.
|
||||
batched_tokens = _stack_protein_tensors(
|
||||
input_tokens,
|
||||
sampled_tokens,
|
||||
sequence_lengths,
|
||||
tokenizers,
|
||||
devices.pop(),
|
||||
@@ -365,8 +370,8 @@ def iterative_sampling_tokens(
|
||||
forward_out, i, keep_dim=True
|
||||
)
|
||||
# Trim logits to proper sequence length for this prompt.
|
||||
per_prompt_forward_out.logits = _trim_sequence_tensor_dataclass(
|
||||
per_prompt_forward_out.logits,
|
||||
per_prompt_forward_out = _trim_sequence_tensor_dataclass(
|
||||
per_prompt_forward_out,
|
||||
# Note(jungong) : we cannot simply use sequence_lengths[i] here,
|
||||
# what we want is for the sequence length of the logits to match
|
||||
# that of the prompt, which may or may not be padded, depending on
|
||||
@@ -568,7 +573,7 @@ def _sample_per_prompt(
|
||||
)
|
||||
mean_embedding = (
|
||||
# [B, L, D] -> [B, D]
|
||||
forward_output.embeddings[0].mean(dim=1) # type: ignore
|
||||
forward_output.embeddings.mean(dim=1) # type: ignore
|
||||
if sampling_config.return_mean_embedding
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from esm.sdk.api import (
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
TokenizerCollection,
|
||||
TokenizerCollectionProtocol,
|
||||
get_invalid_tokenizer_ids,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
@@ -102,7 +102,9 @@ class _BatchedESMProteinTensor(ESMProteinTensor):
|
||||
s[i, ...] = v
|
||||
|
||||
|
||||
def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
|
||||
def get_default_sampling_config(
|
||||
tokenizers: TokenizerCollectionProtocol,
|
||||
) -> SamplingConfig:
|
||||
tracks = [f.name for f in attr.fields(SamplingConfig)]
|
||||
sampling_config = SamplingConfig()
|
||||
for current_track in tracks:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.0.0"
|
||||
version = "3.0.1"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user