Sync from internal

This commit is contained in:
Zeming Lin
2024-07-18 16:54:35 +00:00
parent 00cdadfffa
commit 17d48878a9
9 changed files with 82 additions and 75 deletions

View File

@@ -1 +1 @@
__version__ = "0.2rc1"
__version__ = "3.0.1"

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import contextlib
from functools import partial
from typing import Callable
import attr
import einops
@@ -28,7 +29,9 @@ from esm.sdk.api import (
ProteinType,
SamplingConfig,
)
from esm.tokenization import get_model_tokenizers
from esm.tokenization import (
TokenizerCollectionProtocol,
)
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import ESM3_OPEN_SMALL
@@ -202,9 +205,10 @@ class ESM3(nn.Module, ESM3InferenceClient):
n_heads: int,
v_heads: int,
n_layers: int,
structure_encoder_name: str,
structure_decoder_name: str,
function_decoder_name: str,
structure_encoder_fn: Callable[[torch.device | str], StructureTokenEncoder],
structure_decoder_fn: Callable[[torch.device | str], StructureTokenDecoder],
function_decoder_fn: Callable[[torch.device | str], FunctionTokenDecoder],
tokenizers: TokenizerCollectionProtocol,
):
super().__init__()
self.encoder = EncodeInputs(d_model)
@@ -217,15 +221,15 @@ class ESM3(nn.Module, ESM3InferenceClient):
)
self.output_heads = OutputHeads(d_model)
self.structure_encoder_name = structure_encoder_name
self.structure_decoder_name = structure_decoder_name
self.function_decoder_name = function_decoder_name
self.structure_encoder_fn = structure_encoder_fn
self.structure_decoder_fn = structure_decoder_fn
self.function_decoder_fn = function_decoder_fn
self.structure_encoder: StructureTokenEncoder | None = None
self.structure_decoder: StructureTokenDecoder | None = None
self.function_decoder: FunctionTokenDecoder | None = None
self._structure_encoder = None
self._structure_decoder = None
self._function_decoder = None
self.tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
self.tokenizers = tokenizers
@classmethod
def from_pretrained(
@@ -245,32 +249,24 @@ class ESM3(nn.Module, ESM3InferenceClient):
assert isinstance(model, ESM3)
return model
def get_structure_token_encoder(self) -> StructureTokenEncoder:
if self.structure_encoder is None:
model = self.load_model(self.structure_encoder_name)
assert isinstance(model, StructureTokenEncoder)
self.structure_encoder = model
return self.structure_encoder
@property
def device(self):
return next(self.parameters()).device
def get_structure_token_decoder(self) -> StructureTokenDecoder:
if self.structure_decoder is None:
model = self.load_model(self.structure_decoder_name)
assert isinstance(model, StructureTokenDecoder)
self.structure_decoder = model
return self.structure_decoder
def get_structure_encoder(self) -> StructureTokenEncoder:
if self._structure_encoder is None:
self._structure_encoder = self.structure_encoder_fn(self.device)
return self._structure_encoder
def get_function_token_decoder(self) -> FunctionTokenDecoder:
if self.function_decoder is None:
model = self.load_model(self.function_decoder_name)
assert isinstance(model, FunctionTokenDecoder)
self.function_decoder = model
return self.function_decoder
def get_structure_decoder(self) -> StructureTokenDecoder:
if self._structure_decoder is None:
self._structure_decoder = self.structure_decoder_fn(self.device)
return self._structure_decoder
def load_model(self, model_name: str):
# Lazy import from pretrained
from esm.pretrained import load_local_model
return load_local_model(model_name, device=next(self.parameters()).device)
def get_function_decoder(self) -> FunctionTokenDecoder:
if self._function_decoder is None:
self._function_decoder = self.function_decoder_fn(self.device)
return self._function_decoder
def forward(
self,
@@ -360,15 +356,10 @@ class ESM3(nn.Module, ESM3InferenceClient):
] # In case we pass in an atom14 or atom37 repr
affine, affine_mask = build_affine3d_from_coordinates(structure_coords)
if structure_tokens is None:
_, structure_tokens = self.get_structure_token_encoder().encode(
structure_coords
)
structure_tokens = defaults(structure_tokens, C.STRUCTURE_MASK_TOKEN)
assert structure_tokens is not None
structure_tokens = (
structure_tokens.masked_fill(
(structure_tokens == -1) | ~affine_mask, C.STRUCTURE_MASK_TOKEN
)
structure_tokens.masked_fill(structure_tokens == -1, C.STRUCTURE_MASK_TOKEN)
.masked_fill(sequence_tokens == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN)
.masked_fill(sequence_tokens == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN)
.masked_fill(sequence_tokens == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN)
@@ -405,7 +396,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
configs
), "Must have the same number of prompts and configs."
if inputs is []:
if inputs == []:
# Nothing to do.
return []
@@ -469,7 +460,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
if input.coordinates is not None:
coordinates, _, structure_tokens = encoding.tokenize_structure(
input.coordinates,
self.get_structure_token_encoder(),
self.get_structure_encoder(),
structure_tokenizer=self.tokenizers.structure,
reference_sequence=input.sequence or "",
add_special_tokens=True,
@@ -517,8 +508,8 @@ class ESM3(nn.Module, ESM3InferenceClient):
return decode_protein_tensor(
input=input,
tokenizers=self.tokenizers,
structure_token_decoder=self.get_structure_token_decoder(),
function_token_decoder=self.get_function_token_decoder(),
structure_token_decoder=self.get_structure_decoder(),
function_token_decoder=self.get_function_decoder(),
)
def _forward(

View File

@@ -9,6 +9,7 @@ from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import get_model_tokenizers
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,
@@ -20,24 +21,6 @@ from esm.utils.constants.models import (
ModelBuilder = Callable[[torch.device | str], nn.Module]
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
with torch.device(device):
model = ESM3(
d_model=1536,
n_heads=24,
v_heads=256,
n_layers=48,
structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0,
structure_decoder_name=ESM3_STRUCTURE_DECODER_V0,
function_decoder_name=ESM3_FUNCTION_DECODER_V0,
).eval()
state_dict = torch.load(
data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
)
model.load_state_dict(state_dict)
return model
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
with torch.device(device):
model = StructureTokenEncoder(
@@ -70,6 +53,25 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
return model
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
with torch.device(device):
model = ESM3(
d_model=1536,
n_heads=24,
v_heads=256,
n_layers=48,
structure_encoder_fn=ESM3_structure_encoder_v0,
structure_decoder_fn=ESM3_structure_decoder_v0,
function_decoder_fn=ESM3_function_decoder_v0,
tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
).eval()
state_dict = torch.load(
data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
)
model.load_state_dict(state_dict)
return model
LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
ESM3_OPEN_SMALL: ESM3_sm_open_v0,
ESM3_STRUCTURE_ENCODER_V0: ESM3_structure_encoder_v0,

View File

@@ -193,7 +193,9 @@ class GenerationConfig:
track: str = ""
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
num_steps: int = 8
# Set this to a higher value for better generation results.
# Note that this needs to be less than or equal to the sequence length.
num_steps: int = 1
temperature: float = 1.0
top_p: float = 1.0
condition_on_coordinates_only: bool = True

View File

@@ -61,6 +61,11 @@ def decode_protein_tensor(
track_tokenizer = getattr(tokenizers, track.name)
if torch.all(tokens == track_tokenizer.pad_token_id):
setattr(input, track.name, None)
# If structure track has any mask tokens, do not decode.
if track.name == "structure" and torch.any(
tokens == track_tokenizer.mask_token_id
):
setattr(input, track.name, None)
if input.sequence is not None:
sequence = decode_sequence(input.sequence, tokenizers.sequence)

View File

@@ -104,7 +104,7 @@ def tokenize_structure(
structure_tokens = F.pad(
structure_tokens,
(left_pad, right_pad),
value=structure_tokenizer.pad_token_id,
value=structure_tokenizer.mask_token_id,
)
structure_tokens[0] = structure_tokenizer.bos_token_id
structure_tokens[-1] = structure_tokenizer.eos_token_id
@@ -186,7 +186,7 @@ def get_default_structure_tokens(
(sequence_length + 2,),
dtype=torch.int64,
)
* structure_tokenizer.pad_token_id
* structure_tokenizer.mask_token_id
)
# Always include BOS and EOS tokens
structure_tokens[0] = structure_tokenizer.bos_token_id

View File

@@ -136,6 +136,8 @@ def _make_masked_inputs(
if track == "coordinates":
dims = (sequence_length, 3, 3)
elif track == "confidence":
dims = (sequence_length,)
elif track == "attention_mask":
dims = (sequence_length,)
elif track == "function":
@@ -147,6 +149,9 @@ def _make_masked_inputs(
if track == "coordinates":
masked_tokens = torch.full(dims, torch.inf, dtype=torch.float)
elif track == "confidence":
# All-mask dummy input for confidence track.
masked_tokens = torch.full(dims, 0.0)
elif track == "attention_mask":
masked_tokens = torch.full(dims, 1, dtype=torch.bool)
else:
@@ -302,7 +307,7 @@ def iterative_sampling_tokens(
sequence_lengths = [len(tokens) for tokens in sampled_tokens]
# Figure out the number of tokens to be sampled for each prompt.
total_to_sample = []
for protein, seq_len, config in zip(input_tokens, sequence_lengths, configs):
for protein, seq_len, config in zip(sampled_tokens, sequence_lengths, configs):
track = config.track
if getattr(protein, track) is None:
@@ -324,7 +329,7 @@ def iterative_sampling_tokens(
# Now stack the list to make a single batched ESMProteinTensor.
batched_tokens = _stack_protein_tensors(
input_tokens,
sampled_tokens,
sequence_lengths,
tokenizers,
devices.pop(),
@@ -365,8 +370,8 @@ def iterative_sampling_tokens(
forward_out, i, keep_dim=True
)
# Trim logits to proper sequence length for this prompt.
per_prompt_forward_out.logits = _trim_sequence_tensor_dataclass(
per_prompt_forward_out.logits,
per_prompt_forward_out = _trim_sequence_tensor_dataclass(
per_prompt_forward_out,
# Note(jungong) : we can not smiply use sequence_lenths[i] here,
# what we want is for the sequence length of the logits to match
# that of the prompt, which may or may not be padded, depending on
@@ -568,7 +573,7 @@ def _sample_per_prompt(
)
mean_embedding = (
# [B, L, D] -> [B, D]
forward_output.embeddings[0].mean(dim=1) # type: ignore
forward_output.embeddings.mean(dim=1) # type: ignore
if sampling_config.return_mean_embedding
else None
)

View File

@@ -8,7 +8,7 @@ from esm.sdk.api import (
SamplingTrackConfig,
)
from esm.tokenization import (
TokenizerCollection,
TokenizerCollectionProtocol,
get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
@@ -102,7 +102,9 @@ class _BatchedESMProteinTensor(ESMProteinTensor):
s[i, ...] = v
def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
def get_default_sampling_config(
tokenizers: TokenizerCollectionProtocol,
) -> SamplingConfig:
tracks = [f.name for f in attr.fields(SamplingConfig)]
sampling_config = SamplingConfig()
for current_track in tracks:

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.0.0"
version = "3.0.1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"