mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
update esmc tokenizer and return hidden states (#160)
Signed-off-by: tina-z-jia <145156075+tina-z-jia@users.noreply.github.com>
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
__version__ = "3.1.0"
|
||||
__version__ = "3.1.1"
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class TransformerStack(nn.Module):
|
||||
affine: Affine3D | None = None,
|
||||
affine_mask: torch.Tensor | None = None,
|
||||
chain_id: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass of the TransformerStack.
|
||||
|
||||
@@ -85,6 +85,9 @@ class TransformerStack(nn.Module):
|
||||
*batch_dims, _ = x.shape
|
||||
if chain_id is None:
|
||||
chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
|
||||
hiddens = []
|
||||
for block in self.blocks:
|
||||
x = block(x, sequence_id, affine, affine_mask, chain_id)
|
||||
return self.norm(x), x
|
||||
hiddens.append(x)
|
||||
hiddens = torch.stack(hiddens, dim=0)
|
||||
return self.norm(x), x, hiddens
|
||||
|
||||
@@ -376,7 +376,9 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
function_tokens,
|
||||
residue_annotation_tokens,
|
||||
)
|
||||
x, embedding = self.transformer(x, sequence_id, affine, affine_mask, chain_id)
|
||||
x, embedding, _ = self.transformer(
|
||||
x, sequence_id, affine, affine_mask, chain_id
|
||||
)
|
||||
return self.output_heads(x, embedding)
|
||||
|
||||
# The following methods are for the ESM3InferenceClient interface
|
||||
|
||||
@@ -21,6 +21,7 @@ from esm.tokenization import EsmSequenceTokenizer
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants.models import ESMC_600M
|
||||
from esm.utils.decoding import decode_sequence
|
||||
from esm.utils.misc import stack_variable_length_tensors
|
||||
from esm.utils.sampling import _BatchedESMProteinTensor
|
||||
|
||||
|
||||
@@ -28,6 +29,7 @@ from esm.utils.sampling import _BatchedESMProteinTensor
|
||||
class ESMCOutput:
|
||||
sequence_logits: torch.Tensor
|
||||
embeddings: torch.Tensor | None
|
||||
hidden_states: torch.Tensor | None
|
||||
|
||||
|
||||
class ESMC(nn.Module, ESMCInferenceClient):
|
||||
@@ -73,6 +75,23 @@ class ESMC(nn.Module, ESMCInferenceClient):
|
||||
def raw_model(self):
|
||||
return self
|
||||
|
||||
def _tokenize(self, sequence: list[str]) -> torch.Tensor:
|
||||
pad = self.tokenizer.pad_token_id
|
||||
assert pad is not None
|
||||
return stack_variable_length_tensors(
|
||||
[
|
||||
encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
|
||||
for x in sequence
|
||||
],
|
||||
constant_value=pad,
|
||||
).to(next(self.parameters()).device)
|
||||
|
||||
def _detokenize(self, sequence: torch.Tensor) -> list[str]:
|
||||
pad = self.tokenizer.pad_token_id
|
||||
assert pad is not None
|
||||
assert sequence.ndim == 2
|
||||
return [decode_sequence(x[x != pad][1:-1], self.tokenizer) for x in sequence]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sequence_tokens: torch.Tensor | None = None,
|
||||
@@ -93,9 +112,11 @@ class ESMC(nn.Module, ESMCInferenceClient):
|
||||
sequence_id = sequence_tokens == self.tokenizer.pad_token_id
|
||||
|
||||
x = self.embed(sequence_tokens)
|
||||
x, _ = self.transformer(x, sequence_id=sequence_id)
|
||||
x, _, hiddens = self.transformer(x, sequence_id=sequence_id)
|
||||
sequence_logits = self.sequence_head(x)
|
||||
output = ESMCOutput(sequence_logits=sequence_logits, embeddings=x)
|
||||
output = ESMCOutput(
|
||||
sequence_logits=sequence_logits, embeddings=x, hidden_states=hiddens
|
||||
)
|
||||
return output
|
||||
|
||||
def encode(self, input: ESMProtein) -> ESMProteinTensor:
|
||||
@@ -103,9 +124,7 @@ class ESMC(nn.Module, ESMCInferenceClient):
|
||||
sequence_tokens = None
|
||||
|
||||
if input.sequence is not None:
|
||||
sequence_tokens = encoding.tokenize_sequence(
|
||||
input.sequence, self.tokenizer, add_special_tokens=True
|
||||
)
|
||||
sequence_tokens = self._tokenize([input.sequence])[0]
|
||||
return ESMProteinTensor(sequence=sequence_tokens).to(
|
||||
next(self.parameters()).device
|
||||
)
|
||||
@@ -114,7 +133,7 @@ class ESMC(nn.Module, ESMCInferenceClient):
|
||||
input = attr.evolve(input) # Make a copy
|
||||
|
||||
assert input.sequence is not None
|
||||
sequence = decode_sequence(input.sequence[1:-1], self.tokenizer)
|
||||
sequence = self._detokenize(input.sequence)[0]
|
||||
|
||||
return ESMProtein(sequence=sequence)
|
||||
|
||||
|
||||
@@ -172,7 +172,7 @@ class FunctionTokenDecoder(nn.Module):
|
||||
inputs = token_ids + vocab_offsets[None, :]
|
||||
|
||||
embed = self.embedding(inputs)
|
||||
encoding, _ = self.decoder(embed)
|
||||
encoding, _, _ = self.decoder(embed)
|
||||
pooled = torch.mean(encoding, dim=1)
|
||||
|
||||
return {name: head(pooled) for name, head in self.heads.items()}
|
||||
|
||||
@@ -250,7 +250,7 @@ class StructureTokenEncoder(nn.Module):
|
||||
|
||||
z = self.relative_positional_embedding(res_idxs[:, 0], res_idxs)
|
||||
|
||||
z, _ = self.transformer.forward(
|
||||
z, _, _ = self.transformer.forward(
|
||||
x=z,
|
||||
sequence_id=knn_sequence_id,
|
||||
affine=affine,
|
||||
@@ -397,7 +397,7 @@ class StructureTokenDecoder(nn.Module):
|
||||
|
||||
x = self.embed(structure_tokens)
|
||||
# !!! NOTE: Attention mask is actually unused here so watch out
|
||||
x, _ = self.decoder_stack.forward(
|
||||
x, _, _ = self.decoder_stack.forward(
|
||||
x, affine=None, affine_mask=None, sequence_id=sequence_id, chain_id=chain_id
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.tokenization import get_model_tokenizers
|
||||
from esm.tokenization import (
|
||||
get_esm3_model_tokenizers,
|
||||
get_esmc_model_tokenizers,
|
||||
)
|
||||
from esm.utils.constants.esm3 import data_root
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_FUNCTION_DECODER_V0,
|
||||
@@ -62,10 +65,7 @@ def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
|
||||
def ESMC_300M_202412(device: torch.device | str = "cpu"):
|
||||
with torch.device(device):
|
||||
model = ESMC(
|
||||
d_model=960,
|
||||
n_heads=15,
|
||||
n_layers=30,
|
||||
tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
|
||||
d_model=960, n_heads=15, n_layers=30, tokenizer=get_esmc_model_tokenizers()
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
|
||||
@@ -79,10 +79,7 @@ def ESMC_300M_202412(device: torch.device | str = "cpu"):
|
||||
def ESMC_600M_202412(device: torch.device | str = "cpu"):
|
||||
with torch.device(device):
|
||||
model = ESMC(
|
||||
d_model=1152,
|
||||
n_heads=18,
|
||||
n_layers=36,
|
||||
tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
|
||||
d_model=1152, n_heads=18, n_layers=36, tokenizer=get_esmc_model_tokenizers()
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
|
||||
@@ -103,7 +100,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
||||
structure_encoder_fn=ESM3_structure_encoder_v0,
|
||||
structure_decoder_fn=ESM3_structure_decoder_v0,
|
||||
function_decoder_fn=ESM3_function_decoder_v0,
|
||||
tokenizers=get_model_tokenizers(ESM3_OPEN_SMALL),
|
||||
tokenizers=get_esm3_model_tokenizers(ESM3_OPEN_SMALL),
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root("esm3") / "data/weights/esm3_sm_open_v1.pth", map_location=device
|
||||
|
||||
@@ -10,7 +10,7 @@ from attr import asdict, define
|
||||
import esm.utils.constants.api as C
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_model_tokenizers,
|
||||
get_esm3_model_tokenizers,
|
||||
)
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
@@ -226,7 +226,7 @@ class ESMProteinTensor(ProteinType):
|
||||
device: torch.device | str = "cpu",
|
||||
) -> ESMProteinTensor:
|
||||
if tokenizers is None:
|
||||
tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
|
||||
tokenizers = get_esm3_model_tokenizers(ESM3_OPEN_SMALL)
|
||||
|
||||
return ESMProteinTensor(
|
||||
sequence=encoding.get_default_sequence_tokens(
|
||||
|
||||
@@ -34,7 +34,7 @@ class TokenizerCollection:
|
||||
residue_annotations: ResidueAnnotationsTokenizer
|
||||
|
||||
|
||||
def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
|
||||
def get_esm3_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
|
||||
if normalize_model_name(model) == ESM3_OPEN_SMALL:
|
||||
return TokenizerCollection(
|
||||
sequence=EsmSequenceTokenizer(),
|
||||
@@ -48,6 +48,10 @@ def get_model_tokenizers(model: str = ESM3_OPEN_SMALL) -> TokenizerCollection:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
|
||||
|
||||
def get_esmc_model_tokenizers() -> EsmSequenceTokenizer:
|
||||
return EsmSequenceTokenizer()
|
||||
|
||||
|
||||
def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]:
|
||||
if isinstance(tokenizer, EsmSequenceTokenizer):
|
||||
return [
|
||||
|
||||
@@ -10,12 +10,12 @@ from evolutionaryscale.utils.env import ModelName
|
||||
from evolutionaryscale.utils.remote_inference.api_v1 import (
|
||||
ESM3RemoteModelInferenceClient,
|
||||
)
|
||||
from projects.forge.fastapi.utils.model import _load_esm3
|
||||
from projects.forge.fastapi.utils.model import _load_esm_model
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def esm3_remote_inference_client():
|
||||
model = _load_esm3(ModelName.ESM3_TINY_DEV, distributed_model=False)
|
||||
model = _load_esm_model(ModelName.ESM3_TINY_DEV, distributed_model=False)
|
||||
client = ESM3RemoteModelInferenceClient(
|
||||
model,
|
||||
tokenizers=model.tokenizers,
|
||||
|
||||
@@ -1,34 +1,48 @@
|
||||
from esm.models.esmc import ESMC
|
||||
from examples.local_generate import get_sample_protein
|
||||
from esm.sdk.api import (
|
||||
ESMCInferenceClient,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
)
|
||||
from esm.sdk.api import ESMCInferenceClient, ESMProtein, LogitsConfig, LogitsOutput
|
||||
|
||||
|
||||
def main(client: ESMCInferenceClient):
|
||||
# ================================================================
|
||||
# Example usage: one single protein
|
||||
# ================================================================
|
||||
protein = get_sample_protein()
|
||||
protein.coordinates = None
|
||||
protein.function_annotations = None
|
||||
protein.sasa = None
|
||||
protein = ESMProtein(sequence="AAAAA")
|
||||
|
||||
# Use logits endpoint. Using bf16 for inference optimization
|
||||
protein_tensor = client.encode(protein)
|
||||
logits_output = client.logits(
|
||||
output = client.logits(
|
||||
protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
|
||||
)
|
||||
assert isinstance(
|
||||
logits_output, LogitsOutput
|
||||
), f"LogitsOutput was expected but got {logits_output}"
|
||||
assert (
|
||||
logits_output.logits is not None and logits_output.logits.sequence is not None
|
||||
output, LogitsOutput
|
||||
), f"LogitsOutput was expected but got {output}"
|
||||
assert output.logits is not None and output.logits.sequence is not None
|
||||
assert output.embeddings is not None and output.embeddings is not None
|
||||
print(
|
||||
f"Client returned logits with shape: {output.logits.sequence.shape} and embeddings with shape: {output.embeddings.shape}"
|
||||
)
|
||||
|
||||
|
||||
def raw_forward(model: ESMC):
|
||||
protein = ESMProtein(sequence="AAAAA")
|
||||
sequences = [protein.sequence, protein.sequence]
|
||||
|
||||
# ================================================================
|
||||
# Example usage: directly use the model
|
||||
# ================================================================
|
||||
input_ids = model._tokenize(sequences)
|
||||
output = model(input_ids)
|
||||
logits, embeddings, hiddens = (
|
||||
output.sequence_logits,
|
||||
output.embeddings,
|
||||
output.hidden_states,
|
||||
)
|
||||
print(
|
||||
f"Raw model returned logits with shape: {logits.shape}, embeddings with shape: {embeddings.shape} and hidden states with shape {hiddens.shape}"
|
||||
)
|
||||
assert logits_output.embeddings is not None and logits_output.embeddings is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(ESMC.from_pretrained("esmc_300m"))
|
||||
model = ESMC.from_pretrained("esmc_300m")
|
||||
main(model)
|
||||
raw_forward(model)
|
||||
|
||||
@@ -9,7 +9,7 @@ from esm.pretrained import (
|
||||
ESM3_structure_decoder_v0,
|
||||
ESM3_structure_encoder_v0,
|
||||
)
|
||||
from esm.tokenization import get_model_tokenizers
|
||||
from esm.tokenization import get_esm3_model_tokenizers
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
@@ -50,7 +50,7 @@ def inverse_folding_example():
|
||||
|
||||
@torch.no_grad()
|
||||
def conditioned_prediction_example():
|
||||
tokenizers = get_model_tokenizers()
|
||||
tokenizers = get_esm3_model_tokenizers()
|
||||
|
||||
model = ESM3_sm_open_v0("cuda")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.1.0"
|
||||
version = "3.1.1"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
@@ -24,7 +24,7 @@ dependencies = [
|
||||
"torch>=2.2.0",
|
||||
"torchvision",
|
||||
"torchtext",
|
||||
"transformers",
|
||||
"transformers<4.47.0",
|
||||
"ipython",
|
||||
"einops",
|
||||
"biotite==0.41.2",
|
||||
|
||||
Reference in New Issue
Block a user