update esmc tokenizer and return hidden states (#160)

Signed-off-by: tina-z-jia <145156075+tina-z-jia@users.noreply.github.com>
This commit is contained in:
tina-z-jia
2024-12-06 11:29:31 -08:00
committed by GitHub
parent 5604523746
commit 8127b99068
13 changed files with 88 additions and 49 deletions

View File

@@ -1,2 +1,2 @@
__version__ = "3.1.0"
__version__ = "3.1.1"

View File

@@ -66,7 +66,7 @@ class TransformerStack(nn.Module):
affine: Affine3D | None = None,
affine_mask: torch.Tensor | None = None,
chain_id: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass of the TransformerStack.
@@ -85,6 +85,9 @@ class TransformerStack(nn.Module):
*batch_dims, _ = x.shape
if chain_id is None:
chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
hiddens = []
for block in self.blocks:
x = block(x, sequence_id, affine, affine_mask, chain_id)
return self.norm(x), x
hiddens.append(x)
hiddens = torch.stack(hiddens, dim=0)
return self.norm(x), x, hiddens

View File

@@ -376,7 +376,9 @@ class ESM3(nn.Module, ESM3InferenceClient):
function_tokens,
residue_annotation_tokens,
)
x, embedding = self.transformer(x, sequence_id, affine, affine_mask, chain_id)
x, embedding, _ = self.transformer(
x, sequence_id, affine, affine_mask, chain_id
)
return self.output_heads(x, embedding)
# The following methods are for the ESM3InferenceClient interface

View File

@@ -21,6 +21,7 @@ from esm.tokenization import EsmSequenceTokenizer
from esm.utils import encoding
from esm.utils.constants.models import ESMC_600M
from esm.utils.decoding import decode_sequence
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.sampling import _BatchedESMProteinTensor
@@ -28,6 +29,7 @@ from esm.utils.sampling import _BatchedESMProteinTensor
class ESMCOutput:
sequence_logits: torch.Tensor
embeddings: torch.Tensor | None
hidden_states: torch.Tensor | None
class ESMC(nn.Module, ESMCInferenceClient):
@@ -73,6 +75,23 @@ class ESMC(nn.Module, ESMCInferenceClient):
def raw_model(self):
return self
def _tokenize(self, sequence: list[str]) -> torch.Tensor:
    """Tokenize a list of sequence strings into a single stacked tensor.

    Each string is tokenized with ``add_special_tokens=True``; the
    per-sequence tensors are then combined by
    ``stack_variable_length_tensors`` using the tokenizer's pad token id as
    the fill value, and the result is moved to the device of the model's
    first parameter.
    """
    pad_id = self.tokenizer.pad_token_id
    # The pad id is needed as the fill value below, so it must be defined.
    assert pad_id is not None

    per_sequence_tokens = [
        encoding.tokenize_sequence(seq, self.tokenizer, add_special_tokens=True)
        for seq in sequence
    ]
    stacked = stack_variable_length_tensors(
        per_sequence_tokens, constant_value=pad_id
    )

    # Keep the tokens on the same device as the model weights.
    model_device = next(self.parameters()).device
    return stacked.to(model_device)
def _detokenize(self, sequence: torch.Tensor) -> list[str]:
    """Decode token ids back into sequence strings.

    Args:
        sequence: Token ids, either as a 2-D tensor with one row per
            sequence, or as a 1-D tensor holding a single sequence. A 1-D
            input is treated as a batch of one, so callers that hold a
            single un-batched sequence (such as ``decode``) can pass it
            directly.

    Returns:
        A list with one decoded string per row of ``sequence``.

    Raises:
        AssertionError: If the tokenizer has no pad token id, or if
            ``sequence`` is neither 1-D nor 2-D.
    """
    pad = self.tokenizer.pad_token_id
    assert pad is not None

    # Promote a single sequence to a batch of one instead of rejecting it.
    if sequence.ndim == 1:
        sequence = sequence.unsqueeze(0)
    assert sequence.ndim == 2, (
        f"expected a 1-D or 2-D token tensor, got {sequence.ndim} dimensions"
    )

    # For every row: drop pad positions, then drop the first and last
    # remaining tokens (presumably the special tokens added when tokenizing
    # with add_special_tokens=True) before decoding.
    return [
        decode_sequence(row[row != pad][1:-1], self.tokenizer) for row in sequence
    ]
def forward(
self,
sequence_tokens: torch.Tensor | None = None,
@@ -93,9 +112,11 @@ class ESMC(nn.Module, ESMCInferenceClient):
sequence_id = sequence_tokens == self.tokenizer.pad_token_id
x = self.embed(sequence_tokens)
x, _ = self.transformer(x, sequence_id=sequence_id)
x, _, hiddens = self.transformer(x, sequence_id=sequence_id)
sequence_logits = self.sequence_head(x)
output = ESMCOutput(sequence_logits=sequence_logits, embeddings=x)
output = ESMCOutput(
sequence_logits=sequence_logits, embeddings=x, hidden_states=hiddens
)
return output
def encode(self, input: ESMProtein) -> ESMProteinTensor:
@@ -103,9 +124,7 @@ class ESMC(nn.Module, ESMCInferenceClient):
sequence_tokens = None
if input.sequence is not None:
sequence_tokens = encoding.tokenize_sequence(
input.sequence, self.tokenizer, add_special_tokens=True
)
sequence_tokens = self._tokenize([input.sequence])[0]
return ESMProteinTensor(sequence=sequence_tokens).to(
next(self.parameters()).device
)
@@ -114,7 +133,7 @@ class ESMC(nn.Module, ESMCInferenceClient):
input = attr.evolve(input) # Make a copy
assert input.sequence is not None
sequence = decode_sequence(input.sequence[1:-1], self.tokenizer)
sequence = self._detokenize(input.sequence)[0]
return ESMProtein(sequence=sequence)

View File

@@ -172,7 +172,7 @@ class FunctionTokenDecoder(nn.Module):
inputs = token_ids + vocab_offsets[None, :]
embed = self.embedding(inputs)
encoding, _ = self.decoder(embed)
encoding, _, _ = self.decoder(embed)
pooled = torch.mean(encoding, dim=1)
return {name: head(pooled) for name, head in self.heads.items()}

View File

@@ -250,7 +250,7 @@ class StructureTokenEncoder(nn.Module):
z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)
z, _ = self.transformer.forward(
z, _, _ = self.transformer.forward(
x=z,
sequence_id=knn_sequence_id,
affine=affine,
@@ -397,7 +397,7 @@ class StructureTokenDecoder(nn.Module):
x = self.embed(structure_tokens)
# !!! NOTE: Attention mask is actually unused here so watch out
x, _ = self.decoder_stack.forward(
x, _, _ = self.decoder_stack.forward(
x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id
)

View File

@@ -10,7 +10,10 @@ from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import get_model_tokenizers
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,
@@ -62,10 +65,7 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
def ESMC_300M_202412(device: torch.device | str = "cpu"):
with torch.device(device):
model = ESMC(
d_model=960,
n_heads=15,
n_layers=30,
tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
).eval()
state_dict = torch.load(
data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
@@ -79,10 +79,7 @@ def ESMC_300M_202412(device: torch.device | str = "cpu"):
def ESMC_600M_202412(device: torch.device | str = "cpu"):
with torch.device(device):
model = ESMC(
d_model=1152,
n_heads=18,
n_layers=36,
tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
).eval()
state_dict = torch.load(
data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
@@ -103,7 +100,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
structure_encoder_fn=ESM3_structure_encoder_v0,
structure_decoder_fn=ESM3_structure_decoder_v0,
function_decoder_fn=ESM3_function_decoder_v0,
tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
tokenizers=get_esm3_model_tokenizers(ESM3_OPEN_SMALL),
).eval()
state_dict = torch.load(
data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device

View File

@@ -10,7 +10,7 @@ from attr import asdict, define
import esm.utils.constants.api as C
from esm.tokenization import (
TokenizerCollectionProtocol,
get_model_tokenizers,
get_esm3_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
@@ -226,7 +226,7 @@ class ESMProteinTensor(ProteinType):
device: torch.device | str = "cpu",
) -> ESMProteinTensor:
if tokenizers is None:
tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)
return ESMProteinTensor(
sequence=encoding.get_default_sequence_tokens(

View File

@@ -34,7 +34,7 @@ class TokenizerCollection:
residue_annotations: ResidueAnnotationsTokenizer
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
if normalize_model_name(model) == ESM3_OPEN_SMALL:
return TokenizerCollection(
sequence=EsmSequenceTokenizer(),
@@ -48,6 +48,10 @@ def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
raise ValueError(f"Unknown model: {model}")
def get_esmc_model_tokenizers() -> EsmSequenceTokenizer:
    """Return a newly constructed ``EsmSequenceTokenizer`` for ESMC models."""
    sequence_tokenizer = EsmSequenceTokenizer()
    return sequence_tokenizer
def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
if isinstance(tokenizer, EsmSequenceTokenizer):
return [

View File

@@ -10,12 +10,12 @@ from evolutionaryscale.utils.env import ModelName
from evolutionaryscale.utils.remote_inference.api_v1 import (
ESM3RemoteModelInferenceClient,
)
from projects.forge.fastapi.utils.model import _load_esm3
from projects.forge.fastapi.utils.model import _load_esm_model
@pytest.fixture()
def esm3_remote_inference_client():
model = _load_esm3(ModelName.ESM3_TINY_DEV, distributed_model=False)
model = _load_esm_model(ModelName.ESM3_TINY_DEV, distributed_model=False)
client = ESM3RemoteModelInferenceClient(
model,
tokenizers=model.tokenizers,

View File

@@ -1,34 +1,48 @@
from esm.models.esmc import ESMC
from examples.local_generate import get_sample_protein
from esm.sdk.api import (
ESMCInferenceClient,
LogitsConfig,
LogitsOutput,
)
from esm.sdk.api import ESMCInferenceClient, ESMProtein, LogitsConfig, LogitsOutput
def main(client: ESMCInferenceClient):
# ================================================================
# Example usage: one single protein
# ================================================================
protein = get_sample_protein()
protein.coordinates = None
protein.function_annotations = None
protein.sasa = None
protein = ESMProtein(sequence="AAAAA")
# Use logits endpoint. Using bf16 for inference optimization
protein_tensor = client.encode(protein)
logits_output = client.logits(
output = client.logits(
protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
assert isinstance(
logits_output, LogitsOutput
), f"LogitsOutput was expected but got {logits_output}"
assert (
logits_output.logits is not None and logits_output.logits.sequence is not None
output, LogitsOutput
), f"LogitsOutput was expected but got {output}"
assert output.logits is not None and output.logits.sequence is not None
assert output.embeddings is not None and output.embeddings is not None
print(
f"Client returned logits with shape: {output.logits.sequence.shape} and embeddings with shape: {output.embeddings.shape}"
)
def raw_forward(model: ESMC):
    """Call the model directly on a two-sequence batch and print output shapes."""
    # ================================================================
    # Example usage: directly use the model
    # ================================================================
    protein = ESMProtein(sequence="AAAAA")
    # Batch of two identical sequences to exercise the batched tokenizer.
    sequences = [protein.sequence] * 2

    input_ids = model._tokenize(sequences)
    output = model(input_ids)

    logits = output.sequence_logits
    embeddings = output.embeddings
    hiddens = output.hidden_states
    print(
        f"Raw model returned logits with shape: {logits.shape}, embeddings with shape: {embeddings.shape} and hidden states with shape {hiddens.shape}"
    )
assert logits_output.embeddings is not None and logits_output.embeddings is not None
if __name__ == "__main__":
main(ESMC.from_pretrained("esmc_300m"))
model = ESMC.from_pretrained("esmc_300m")
main(model)
raw_forward(model)

View File

@@ -9,7 +9,7 @@ from esm.pretrained import (
ESM3_structure_decoder_v0,
ESM3_structure_encoder_v0,
)
from esm.tokenization import get_model_tokenizers
from esm.tokenization import get_esm3_model_tokenizers
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
@@ -50,7 +50,7 @@ def inverse_folding_example():
@torch.no_grad()
def conditioned_prediction_example():
tokenizers = get_model_tokenizers()
tokenizers = get_esm3_model_tokenizers()
model = ESM3_sm_open_v0("cuda")

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.1.0"
version = "3.1.1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"
@@ -24,7 +24,7 @@ dependencies = [
"torch>=2.2.0",
"torchvision",
"torchtext",
"transformers",
"transformers<4.47.0",
"ipython",
"einops",
"biotite==0.41.2",