This commit is contained in:
Jun Gong
2024-07-12 19:56:47 +00:00
parent b16e190d6c
commit 95e3c5be8a
28 changed files with 2226 additions and 546 deletions

View File

@@ -0,0 +1 @@
__version__ = "0.2rc1"

View File

@@ -59,12 +59,19 @@ class MultiHeadAttention(nn.Module):
reshaper, (query_BLD, key_BLD, value_BLD)
)
# Where True, enable participation in attention.
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)
if seq_id is not None:
# Where True, enable participation in attention.
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)
context_BHLD = F.scaled_dot_product_attention(
query_BHLD, key_BHLD, value_BHLD, mask_BHLL
)
context_BHLD = F.scaled_dot_product_attention(
query_BHLD, key_BHLD, value_BHLD, mask_BHLL
)
else:
# Shortcut, if we don't use attention biases then torch
# will autoselect flashattention as the implementation
context_BHLD = F.scaled_dot_product_attention(
query_BHLD, key_BHLD, value_BHLD
)
context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
return self.out_proj(context_BLD)

View File

@@ -50,6 +50,8 @@ class GeometricReasoningOriginalImpl(nn.Module):
self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
def forward(self, s, affine, affine_mask, sequence_id, chain_id):
if sequence_id is None:
sequence_id = torch.zeros_like(s[..., 0], dtype=torch.int64)
attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
attn_bias = attn_bias.unsqueeze(1).float()
attn_bias = attn_bias.masked_fill(

View File

@@ -83,10 +83,6 @@ class TransformerStack(nn.Module):
pre_norm: The embedding of shape (batch_size, sequence_length, d_model).
"""
*batch_dims, _ = x.shape
if sequence_id is None:
sequence_id = torch.ones(
size=batch_dims, dtype=torch.int64, device=x.device
)
if chain_id is None:
chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
for block in self.blocks:

View File

@@ -26,9 +26,7 @@ from esm.sdk.api import (
ForwardTrackData,
GenerationConfig,
ProteinType,
ReturnLogitsConfig,
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import get_model_tokenizers
from esm.utils import encoding
@@ -36,15 +34,16 @@ from esm.utils.constants import esm3 as C
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
_sample_per_prompt,
_slice_tensor_dataclass,
iterative_sampling_raw,
iterative_sampling_tokens,
)
from esm.utils.misc import rbf
from esm.utils.sampling import (
_BatchedESMProteinTensor,
get_default_sampling_config,
sample_function_logits,
sample_logits,
sample_residue_annotation_logits,
)
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
@@ -222,9 +221,9 @@ class ESM3(nn.Module, ESM3InferenceClient):
self.structure_decoder_name = structure_decoder_name
self.function_decoder_name = function_decoder_name
self.structure_encoder: StructureTokenEncoder | None = None # type: ignore
self.structure_decoder: StructureTokenDecoder | None = None # type: ignore
self.function_decoder: FunctionTokenDecoder | None = None # type: ignore
self.structure_encoder: StructureTokenEncoder | None = None
self.structure_decoder: StructureTokenDecoder | None = None
self.function_decoder: FunctionTokenDecoder | None = None
self.tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
@@ -232,29 +231,40 @@ class ESM3(nn.Module, ESM3InferenceClient):
def from_pretrained(
cls,
model_name: str = ESM3_OPEN_SMALL,
device: torch.device | str = "cpu",
device: torch.device | None = None,
) -> ESM3:
from esm.pretrained import load_local_model
if model_name not in [ESM3_OPEN_SMALL]:
raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
model: ESM3 = load_local_model(model_name, device=device) # type: ignore
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_local_model(model_name, device=device)
if device.type != "cpu":
model = model.to(torch.bfloat16)
assert isinstance(model, ESM3)
return model
def get_structure_token_encoder(self) -> StructureTokenEncoder:
if self.structure_encoder is None:
self.structure_encoder = self.load_model(self.structure_encoder_name) # type: ignore
return self.structure_encoder # type: ignore
model = self.load_model(self.structure_encoder_name)
assert isinstance(model, StructureTokenEncoder)
self.structure_encoder = model
return self.structure_encoder
def get_structure_token_decoder(self) -> StructureTokenDecoder:
if self.structure_decoder is None:
self.structure_decoder = self.load_model(self.structure_decoder_name) # type: ignore
return self.structure_decoder # type: ignore
model = self.load_model(self.structure_decoder_name)
assert isinstance(model, StructureTokenDecoder)
self.structure_decoder = model
return self.structure_decoder
def get_function_token_decoder(self) -> FunctionTokenDecoder:
if self.function_decoder is None:
self.function_decoder = self.load_model(self.function_decoder_name) # type: ignore
return self.function_decoder # type: ignore
model = self.load_model(self.function_decoder_name)
assert isinstance(model, FunctionTokenDecoder)
self.function_decoder = model
return self.function_decoder
def load_model(self, model_name: str):
# Lazy import from pretrained
@@ -324,12 +334,11 @@ class ESM3(nn.Module, ESM3InferenceClient):
torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x
)
sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id)
ss8_tokens = defaults(ss8_tokens, C.SS8_UNK_TOKEN)
sasa_tokens = defaults(sasa_tokens, C.SASA_UNK_TOKEN)
ss8_tokens = defaults(ss8_tokens, C.SS8_PAD_TOKEN)
sasa_tokens = defaults(sasa_tokens, C.SASA_PAD_TOKEN)
average_plddt = defaults(average_plddt, 1).float()
per_res_plddt = defaults(per_res_plddt, 0).float()
chain_id = defaults(chain_id, 0)
sequence_id = defaults(sequence_id, 0)
if residue_annotation_tokens is None:
residue_annotation_tokens = torch.full(
@@ -384,10 +393,39 @@ class ESM3(nn.Module, ESM3InferenceClient):
# The following methods are for the ESM3InferenceClient interface
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
if isinstance(input, ESMProtein):
return iterative_sampling_raw(self, input, config)
elif isinstance(input, ESMProteinTensor):
return iterative_sampling_tokens(self, input, config, self.tokenizers)
"""Wrap around batched generation."""
proteins = self.batch_generate([input], [config])
assert len(proteins) == 1
return proteins[0]
def batch_generate(
    self, inputs: list[ProteinType], configs: list[GenerationConfig]
) -> list[ProteinType]:
    """Generate for a batch of prompts, one GenerationConfig per prompt.

    Args:
        inputs: Prompts to generate from. All prompts must be of the same
            type (all ESMProtein or all ESMProteinTensor).
        configs: Generation configs, aligned one-to-one with ``inputs``.

    Returns:
        The generated proteins as returned by the iterative sampling
        routine for the prompt type. An empty batch yields an empty list.

    Raises:
        AssertionError: If ``inputs`` and ``configs`` differ in length, or
            if the prompts are not all of the same type.
        ValueError: If the prompts are neither ESMProtein nor
            ESMProteinTensor.
    """
    assert len(inputs) == len(
        configs
    ), "Must have the same number of prompts and configs."

    # An identity comparison against a fresh list literal is never true, so
    # test emptiness by length; otherwise inputs[0] below raises IndexError.
    if len(inputs) == 0:
        # Nothing to do.
        return []

    # Make sure prompts are of the same type.
    t = type(inputs[0])
    for prompt in inputs[1:]:
        assert isinstance(prompt, t), (
            "Prompts must have the same type. Got "
            f"{t.__name__} and {type(prompt).__name__} instead."
        )

    if isinstance(inputs[0], ESMProtein):
        return iterative_sampling_raw(self, inputs, configs)  # type: ignore
    elif isinstance(inputs[0], ESMProteinTensor):
        return iterative_sampling_tokens(
            self,
            inputs,  # type: ignore
            configs,
            self.tokenizers,  # type: ignore
        )
    else:
        raise ValueError("Input must be an ESMProtein or ESMProteinTensor")
@@ -486,6 +524,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
def _forward(
self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
) -> ForwardOutput:
device = torch.device(input.device)
# Default plddt conditioning for inference. 1s where coordinates are provided.
if input.coordinates is None:
per_res_plddt = None
@@ -493,7 +532,12 @@ class ESM3(nn.Module, ESM3InferenceClient):
# 1.0 if all coordinates at specific indices have valid non-nan values.
per_res_plddt = input.coordinates.isfinite().all(dim=-1).any(dim=-1).float()
with torch.no_grad() if self.eval else contextlib.nullcontext():
with (
torch.no_grad(), # Assume no gradients for now...
torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore
if device.type == "cuda"
else contextlib.nullcontext(),
):
output = self.forward(
sequence_tokens=input.sequence,
structure_tokens=input.structure,
@@ -508,31 +552,32 @@ class ESM3(nn.Module, ESM3InferenceClient):
sequence_id=None,
)
if config.return_logits:
logits = ForwardTrackData(
sequence=output.sequence_logits,
structure=output.structure_logits,
secondary_structure=output.secondary_structure_logits,
sasa=output.sasa_logits,
function=output.function_logits,
)
else:
logits = None
output = ESMOutput(
**{k: v.to(device).to(torch.float32) for k, v in vars(output).items()}
)
return ForwardOutput(
logits=logits,
residue_annotation_logits=output.residue_logits,
embeddings=output.embeddings if config.return_embeddings else None,
if config.return_logits:
logits = ForwardTrackData(
sequence=output.sequence_logits,
structure=output.structure_logits,
secondary_structure=output.secondary_structure_logits,
sasa=output.sasa_logits,
function=output.function_logits,
)
else:
logits = None
return ForwardOutput(
logits=logits,
residue_annotation_logits=output.residue_logits,
embeddings=output.embeddings if config.return_embeddings else None,
)
def forward_and_sample(
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput:
protein_tensor = attr.evolve(input) # Make a copy
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
return x.clone() if x is not None else None
device = next(self.parameters()).device
sampling_config = sampling_configuration
@@ -551,249 +596,20 @@ class ESM3(nn.Module, ESM3InferenceClient):
getattr(default_protein_tensor, track.name, None),
)
# Preprocessing
sequence_length: int = -1
for track in [
"sequence",
"structure",
"secondary_structure",
"sasa",
"function",
"residue_annotations",
]:
input_tensor: torch.Tensor | None = getattr(protein_tensor, track, None)
if input_tensor is not None:
# Add batch dimension if necessary
if track in ["sequence", "structure", "secondary_structure", "sasa"]:
if len(input_tensor.size()) == 1:
input_tensor = input_tensor.unsqueeze(0) # (L,) -> (1, L)
elif track in ["function", "residue_annotations"]:
if len(input_tensor.size()) == 2:
input_tensor = input_tensor.unsqueeze(0) # (L, O) -> (1, L, O)
# Check length consistency
if sequence_length == -1:
sequence_length = input_tensor.size(1)
else:
if input_tensor.size(1) != sequence_length:
raise ValueError(
f"Length mismatch for track {track}. Expected {sequence_length}, got {input_tensor.size(1)}"
)
# Move input tensor to model device
input_tensor = input_tensor.to(device)
setattr(protein_tensor, track, input_tensor)
if protein_tensor.coordinates is not None:
coordinates = protein_tensor.coordinates
if len(coordinates.size()) == 3:
coordinates = coordinates.unsqueeze(0)
protein_tensor.coordinates = coordinates.to(device)
sequence_length = coordinates.size(1)
if sequence_length == -1:
if len(protein_tensor) <= 0:
raise ValueError("No input data provided")
# Forward pass
forward_output = self._forward(
protein_tensor,
ForwardConfig(
ReturnLogitsConfig(
sequence=True,
structure=True,
secondary_structure=True,
sasa=True,
function=True,
residue_annotations=True,
),
return_embeddings=True,
),
# Move input protein to proper device.
batched_protein = _BatchedESMProteinTensor.from_protein_tensor(protein_tensor)
batched_protein.to(device)
forward_output: ForwardOutput = _batch_forward(self, batched_protein)
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
batched_protein,
forward_output,
sampling_config,
self.tokenizers,
)
# Sampling
tokens_dir = {}
track_sampling_metadata_dir: dict[str, dict | None] = {}
for track in ["sequence", "structure", "secondary_structure", "sasa"]:
config = getattr(sampling_config, track)
if config is None:
tokens_dir[track] = maybe_clone(getattr(input, track))
continue
sampling_metadata = self._sample_track(
logits=getattr(forward_output.logits, track)[0, ...],
tokens=getattr(protein_tensor, track)[0, ...],
sampling_track_config=config,
mask_idx=getattr(self.tokenizers, track).mask_token_id,
)
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
track_sampling_metadata_dir[track] = sampling_metadata
# Sample function and residue annotations separately
config = getattr(sampling_config, "function")
if config is None:
tokens_dir["function"] = maybe_clone(getattr(input, "function"))
tokens_dir["residue_annotations"] = maybe_clone(
getattr(input, "residue_annotations")
)
else:
sampling_metadata = self._sample_function_track(
tokens=getattr(protein_tensor, "function")[0, ...],
logits=getattr(forward_output.logits, "function")[0, ...],
sampling_track_config=config,
)
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
track_sampling_metadata_dir["function"] = sampling_metadata
sampled_tokens, _ = sample_residue_annotation_logits(
logits=forward_output.residue_annotation_logits[0, ...] # type: ignore
)
tokens_dir["residue_annotations"] = sampled_tokens # (L, MAX_R)
# Format output
forward_and_sample_output_dir = {}
forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
for property in [
"entropy",
"prob",
"logprob",
"top_prob",
"topk_logprob",
"topk_tokens",
]:
is_all_none = True
forward_track_data_dir = {}
for track in track_sampling_metadata_dir.keys():
values = track_sampling_metadata_dir[track]
if values is not None and values.get(property, None) is not None:
forward_track_data_dir[track] = values.get(property, None)
is_all_none = False
if not is_all_none:
forward_and_sample_output_dir[property] = ForwardTrackData(
**forward_track_data_dir
)
else:
forward_and_sample_output_dir[property] = None
perres_embed = (
forward_output.embeddings[0] # type: ignore
if sampling_configuration.return_per_residue_embeddings
else None
)
mean_embedding = (
forward_output.embeddings[0].mean(0) # type: ignore
if sampling_configuration.return_mean_embedding
else None
)
return ForwardAndSampleOutput(
per_residue_embedding=perres_embed,
mean_embedding=mean_embedding,
**forward_and_sample_output_dir,
)
def _sample_track(
self,
logits: torch.Tensor,
tokens: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
mask_idx: int,
) -> dict[str, torch.Tensor]:
# Sample in all positions
temperature = sampling_track_config.temperature
sampled_tokens = sample_logits(
logits, temperature=temperature, top_p=sampling_track_config.top_p
)
log_probs = logits.log_softmax(-1)
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (L, )
sampling_mask[0] = False
sampling_mask[-1] = False
# Do not sample at special token positions but allow sampling at mask token
special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
if len(special_minus_mask) > 0:
special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
assert special_tokens.numel() > 0
sampling_mask = sampling_mask & (
tokens[..., None] != special_tokens[None, :]
).all(-1)
# Keep only samples from masked positions (if specified)
if sampling_track_config.only_sample_masked_tokens:
masked_tokens = tokens == mask_idx
sampling_mask = sampling_mask & masked_tokens
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)
return self._compute_track_metadata(
sampled_tokens,
log_probs,
sampling_mask,
top_k=sampling_track_config.topk_logprobs,
)
def _sample_function_track(
self,
tokens: torch.Tensor,
logits: torch.Tensor,
sampling_track_config: SamplingTrackConfig,
) -> dict[str, torch.Tensor]:
# Do not sample at BOS and EOS tokens
sampling_mask = torch.ones_like(tokens, dtype=torch.bool)
sampling_mask[0] = False
sampling_mask[-1] = False
sampled_tokens, probs = sample_function_logits(
logits,
self.tokenizers.function,
top_p=sampling_track_config.top_p,
temperature=sampling_track_config.temperature,
)
if sampling_track_config.only_sample_masked_tokens:
raise ValueError(
"Sampling only masked tokens is undefined for function tokens."
)
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens) # (L, D)
return self._compute_track_metadata(
sampled_tokens,
probs,
sampling_mask,
top_k=sampling_track_config.topk_logprobs,
)
@staticmethod
def _compute_track_metadata(
sampled_tokens: torch.Tensor,
log_probs: torch.Tensor,
sampling_mask: torch.Tensor,
top_k: int,
) -> dict:
probs = torch.exp(log_probs) # (B, L)
entropy = torch.distributions.Categorical(probs=probs).entropy() # (B, L)
# Only compute probabilities for sampled tokens
sampled_logprob = torch.zeros_like(
sampled_tokens, dtype=torch.float32
) # (B, L)
sampled_tokens_valid = sampled_tokens[sampling_mask]
sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid]
sampled_logprob[sampling_mask] = sampled_log_probs_valid
# Calculate extra metadata
sampled_prob = torch.exp(sampled_logprob)
top_prob = torch.max(probs, dim=-1).values
topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
topk_logprobs = None if top_k == 0 else topk_logprobs
topk_tokens = None if top_k == 0 else topk_tokens
return {
"entropy": entropy,
"sampled_tokens": sampled_tokens,
"prob": sampled_prob,
"logprob": sampled_logprob,
"top_prob": top_prob,
"topk_logprob": topk_logprobs,
"topk_tokens": topk_tokens,
}
# There is only 1 prompt to sample for.
return _slice_tensor_dataclass(forward_and_sample_out, 0)

View File

@@ -8,6 +8,7 @@ import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from cloudpathlib import AnyPath
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
@@ -70,7 +71,7 @@ class FunctionTokenDecoder(nn.Module):
}
assert len(self.interpro_ids) == config.num_interpro_classes
with open(config.keyword_vocabulary_path, "r") as f:
with AnyPath(config.keyword_vocabulary_path).open("r") as f:
self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n"))
assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size
@@ -245,8 +246,8 @@ class FunctionTokenDecoder(nn.Module):
interpro_id = self.interpro_ids[class_index]
annotation = FunctionAnnotation(
label=interpro_id,
start=position_index + 1, # zero-index -> one-index inclusive
end=position_index + 1, # zero-index -> one-index inclusive
start=position_index, # one-index inclusive (BOS shifts indexes +1)
end=position_index, # one-index inclusive
)
annotations.append(annotation)
@@ -300,8 +301,8 @@ class FunctionTokenDecoder(nn.Module):
for range_ in merge_ranges(ranges):
annotation = FunctionAnnotation(
label=keyword,
start=range_.start + 1, # zero-index -> one-index
end=range_.stop + 1 - 1, # zero-index excl -> one-index incl
start=range_.start, # one-index inclusive (BOS shifts indexes +1)
end=range_.stop - 1, # one-index exclusive -> one-index inclusive
)
annotations.append(annotation)
@@ -332,8 +333,8 @@ def _merge_annotations(
for range_ in merged_ranges:
annotation = FunctionAnnotation(
label=label,
start=range_.start + 1, # zero-index -> one-index
end=range_.stop - 1, # zero-index excl -> one-index incl
start=range_.start, # one-index inclusive (BOS shifts indexes +1)
end=range_.stop - 1, # one-index exclusive -> one-index inclusive
)
merged.append(annotation)
return merged

View File

@@ -235,12 +235,14 @@ class StructureTokenEncoder(nn.Module):
knn_sequence_id = (
node_gather(sequence_id.unsqueeze(-1), knn_edges).view(-1, E)
if sequence_id is not None
else torch.zeros(L, E, dtype=torch.int64, device=coords.device)
else torch.zeros(B * L, E, dtype=torch.int64, device=coords.device)
)
knn_affine_mask = node_gather(affine_mask.unsqueeze(-1), knn_edges).view(
-1, E
)
knn_chain_id = torch.zeros(L, E, dtype=torch.int64, device=coords.device)
knn_chain_id = torch.zeros(
B * L, E, dtype=torch.int64, device=coords.device
)
if residue_index is None:
res_idxs = knn_edges.view(-1, E)

View File

@@ -21,8 +21,8 @@ ModelBuilder = Callable[[torch.device | str], nn.Module]
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
model = (
ESM3(
with torch.device(device):
model = ESM3(
d_model=1536,
n_heads=24,
v_heads=256,
@@ -30,10 +30,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0,
structure_decoder_name=ESM3_STRUCTURE_DECODER_V0,
function_decoder_name=ESM3_FUNCTION_DECODER_V0,
)
.to(device)
.eval()
)
).eval()
state_dict = torch.load(
data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
)
@@ -42,13 +39,10 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
model = (
StructureTokenEncoder(
with torch.device(device):
model = StructureTokenEncoder(
d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
)
.to(device)
.eval()
)
).eval()
state_dict = torch.load(
data_root() / "data/weights/esm3_structure_encoder_v0.pth", map_location=device
)
@@ -57,9 +51,8 @@ def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
model = (
StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).to(device).eval()
)
with torch.device(device):
model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
state_dict = torch.load(
data_root() / "data/weights/esm3_structure_decoder_v0.pth", map_location=device
)
@@ -68,7 +61,8 @@ def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
model = FunctionTokenDecoder().to(device).eval()
with torch.device(device):
model = FunctionTokenDecoder().eval()
state_dict = torch.load(
data_root() / "data/weights/esm3_function_decoder_v0.pth", map_location=device
)
@@ -84,7 +78,9 @@ LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
}
def load_local_model(model_name: str, device: torch.device | str = "cpu") -> nn.Module:
def load_local_model(
model_name: str, device: torch.device = torch.device("cpu")
) -> nn.Module:
if model_name not in LOCAL_MODEL_REGISTRY:
raise ValueError(f"Model {model_name} not found in local model registry.")
return LOCAL_MODEL_REGISTRY[model_name](device)

11
esm/sdk/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
import os
from esm.sdk.forge import ESM3ForgeInferenceClient
def client(
    model="esm3-sm-open-v1",
    url="https://forge.evolutionaryscale.ai",
    token=None,
):
    """Build an ESM3ForgeInferenceClient.

    Args:
        model: Model name passed through to the client.
        url: Base URL passed through to the client.
        token: API token. When omitted (None), the ``ESM_API_KEY``
            environment variable is used, falling back to an empty string.

    Returns:
        An ESM3ForgeInferenceClient constructed from the three values.
    """
    if token is None:
        # Read the environment at call time. A default-argument expression
        # is evaluated only once, when the module is imported, so a key
        # exported after the import would otherwise be ignored.
        token = os.environ.get("ESM_API_KEY", "")
    return ESM3ForgeInferenceClient(model, url, token)

View File

@@ -1,11 +1,11 @@
from __future__ import annotations
from abc import ABC
from typing import Sequence, TypeVar
from typing import Sequence
import attr
import torch
from attr import define
from attr import asdict, define
from esm.tokenization import (
TokenizerCollectionProtocol,
@@ -21,9 +21,13 @@ from esm.utils.types import (
)
class ProteinType(ABC):
...
## Basic Types
@define
class ESMProtein:
class ESMProtein(ProteinType):
# Tracks
sequence: str | None = None
secondary_structure: str | None = None
@@ -101,13 +105,15 @@ class ESMProtein:
entity_id=None,
residue_index=None,
insertion_code=None,
confidence=None if self.plddt is None else self.plddt.detach().cpu().numpy(),
confidence=None
if self.plddt is None
else self.plddt.detach().cpu().numpy(),
)
return protein_chain
@define
class ESMProteinTensor:
class ESMProteinTensor(ProteinType):
sequence: torch.Tensor | None = None
structure: torch.Tensor | None = None
secondary_structure: torch.Tensor | None = None
@@ -116,59 +122,33 @@ class ESMProteinTensor:
residue_annotations: torch.Tensor | None = None
coordinates: torch.Tensor | None = None
def _detect_attribute(self, func, msg):
    """Compute one attribute shared by all populated fields.

    ``func(name, value)`` is evaluated for every field whose value is not
    None. Returns None when no field is populated, the common result when
    all populated fields agree, and raises ValueError (mentioning ``msg``)
    when they disagree.
    """
    mapped = {}
    for field_name, value in asdict(self).items():
        if value is not None:
            mapped[field_name] = func(field_name, value)

    distinct = set(mapped.values())
    if len(distinct) <= 0:
        return None
    if len(distinct) > 1:
        raise ValueError(f"Either no tracks or inconsistent {msg}: {mapped}")
    (only,) = distinct
    return only
def __len__(self) -> int:
if self.sequence is not None:
return self.sequence.size(0)
elif self.structure is not None:
return self.structure.size(0)
elif self.secondary_structure is not None:
return self.secondary_structure.size(0)
elif self.sasa is not None:
return self.sasa.size(0)
elif self.coordinates is not None:
return self.coordinates.size(0)
else:
raise ValueError("No track to determine length from.")
l = self._detect_attribute(lambda _, x: x.size(0), "length")
return l if l is not None else 0
@property
def device(self) -> str | torch.device:
device_ = None
tracks = [f.name for f in attr.fields(ESMProteinTensor)]
for track in tracks:
current_track: torch.Tensor | None = getattr(self, track)
if current_track is not None:
if device_ is not None and device_ != current_track.device:
raise ValueError(f"Inconsistent devices for track {track}.")
device_ = getattr(self, track).device
if device_ is None:
raise ValueError("No track to determine device from.")
return device_
def to(self, device: str | torch.device | None) -> ESMProteinTensor:
if device is None:
return self
device = torch.device(device)
d = self._detect_attribute(lambda _, x: x.device, "device")
assert d is not None
return d
def to(self, device_or_dtype: str | torch.device | torch.dtype) -> ESMProteinTensor:
def _to(name):
v = getattr(self, name)
if v is not None:
setattr(self, name, v.to(device))
setattr(self, name, v.to(device_or_dtype))
for n in [
"sequence",
"structure",
"secondary_structure",
"sasa",
"function",
"residue_annotations",
"coordinates",
]:
_to(n)
for n in attr.fields(ESMProteinTensor):
_to(n.name)
return self
@@ -202,6 +182,11 @@ class ESMProteinTensor:
)
@define
class ESMProteinError(Exception, ProteinType):
error_msg: str
## High Level Endpoint Types
@define
class GenerationConfig:
@@ -285,14 +270,10 @@ class ForwardAndSampleOutput(ForwardOutput):
topk_logprob: ForwardTrackData | None = None
# Which tokens correspond to top probability
topk_tokens: ForwardTrackData | None = None
per_residue_embedding: torch.Tensor | None = None
mean_embedding: torch.Tensor | None = None
ProteinType = TypeVar("ProteinType", bound=ESMProteinTensor | ESMProtein)
class ESM3InferenceClient(ABC):
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
# This is the easiest and most flexible way to run ESM3. Generate will
@@ -302,6 +283,14 @@ class ESM3InferenceClient(ABC):
# if a ESMProteinTensor is provided, encode and decode are skipped
raise NotImplementedError
def batch_generate(
self,
inputs: Sequence[ProteinType],
configs: Sequence[GenerationConfig],
) -> Sequence[ProteinType]:
# Same as generate(...), but generates a batch of proteins at once.
raise NotImplementedError
def encode(self, input: ESMProtein) -> ESMProteinTensor:
# Encode allows for encoding RawRepresentation into TokenizedRepresentation.
# This runs the structure_token_encoder, as well as dealing with PDB => atom37 conversion

336
esm/sdk/forge.py Normal file
View File

@@ -0,0 +1,336 @@
import asyncio
from typing import Sequence
import requests
import torch
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
ForwardAndSampleOutput,
ForwardTrackData,
GenerationConfig,
ProteinType,
SamplingConfig,
SamplingTrackConfig,
)
from esm.utils.misc import maybe_list, maybe_tensor
from esm.utils.types import FunctionAnnotation
def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
if l is None or len(l) <= 0:
return None
return [FunctionAnnotation(*t) for t in l]
class ESM3ForgeInferenceClient(ESM3InferenceClient):
def __init__(self, model, url, token):
self.model = model
self.url = url
self.token = token
self.headers = {"Authorization": f"Bearer {self.token}"}
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
    """Run generation for a single prompt.

    Dispatches on the prompt type: ESMProtein prompts go to the
    ``generate`` endpoint, ESMProteinTensor prompts to ``generate_tensor``.

    Args:
        input: The prompt, either an ESMProtein or an ESMProteinTensor.
        config: Generation settings forwarded with the request.

    Returns:
        The generated protein of the same kind as the prompt. An
        ESMProteinError is returned (not raised) when the input type is
        unsupported; the per-type helpers also return one when the request
        raises RuntimeError.
    """
    if isinstance(input, ESMProtein):
        output = self.__generate_protein(input, config)
    elif isinstance(input, ESMProteinTensor):
        output = self.__generate_protein_tensor(input, config)
    else:
        return ESMProteinError(error_msg=f"Unknown input type {type(input)}")

    if (
        isinstance(output, ESMProtein)
        and isinstance(input, ESMProtein)
        and config.track
        not in [
            "function",
            "residue_annotations",
        ]
    ):
        # Function and residue annotation encoding/decoding is lossy
        # There is no guarantee that decoding encoded tokens will yield the same input
        # so, unless one of those tracks was the generation target, carry
        # the caller's original annotations over to the output unchanged.
        output.function_annotations = input.function_annotations

    return output
def batch_generate(
    self, inputs: list[ProteinType], configs: list[GenerationConfig]
) -> Sequence[ProteinType]:
    """Forge supports auto-batching. So batch_generate() for the Forge client
    is as simple as running a collection of generate() calls in parallel.

    The calls run on a thread pool rather than through an asyncio event
    loop, so this method also works when the caller is already inside a
    running event loop (e.g. a notebook) or on a non-main thread.

    Args:
        inputs: Prompts to generate from.
        configs: Generation configs, paired positionally with ``inputs``.

    Returns:
        One result per (input, config) pair, in input order. A call that
        raised is reported as an ESMProteinError carrying the exception
        text instead of aborting the whole batch.
    """
    # Function-scope import keeps this block self-contained.
    from concurrent.futures import ThreadPoolExecutor

    def _result_or_error(future):
        try:
            return future.result()
        except Exception as e:
            return ESMProteinError(str(e))

    with ThreadPoolExecutor() as pool:
        futures = [
            pool.submit(self.generate, protein, config)
            for protein, config in zip(inputs, configs)
        ]
        # Collect in submission order so results line up with the inputs.
        return [_result_or_error(future) for future in futures]
def __generate_protein(
    self,
    input: ESMProtein,
    config: GenerationConfig,
) -> ESMProtein | ESMProteinError:
    """Post an ESMProtein prompt to the ``generate`` endpoint.

    Builds the request from the prompt's tracks and the generation config,
    then rebuilds an ESMProtein from the ``outputs`` entry of the response.
    A RuntimeError raised by the request is returned as an ESMProteinError.
    """
    inputs_payload = {
        "sequence": input.sequence,
        "secondary_structure": input.secondary_structure,
        "sasa": maybe_list(input.sasa),
    }
    # The function track is only sent when annotations are present.
    if input.function_annotations is not None:
        inputs_payload["function"] = [
            annotation.to_tuple() for annotation in input.function_annotations
        ]
    inputs_payload["coordinates"] = maybe_list(
        input.coordinates, convert_nan_to_none=True
    )

    request = dict(
        model=self.model,
        inputs=inputs_payload,
        track=config.track,
        invalid_ids=config.invalid_ids,
        schedule=config.schedule,
        num_steps=config.num_steps,
        temperature=config.temperature,
        top_p=config.top_p,
        condition_on_coordinates_only=config.condition_on_coordinates_only,
    )

    try:
        data = self.__post("generate", request)
    except RuntimeError as e:
        return ESMProteinError(error_msg=str(e))

    outputs = data["outputs"]
    return ESMProtein(
        sequence=outputs["sequence"],
        secondary_structure=outputs["secondary_structure"],
        sasa=outputs["sasa"],
        function_annotations=_list_to_function_annotations(outputs["function"]),
        coordinates=maybe_tensor(outputs["coordinates"], convert_none_to_nan=True),
        plddt=maybe_tensor(outputs["plddt"]),
        ptm=maybe_tensor(outputs["ptm"]),
    )
def __generate_protein_tensor(
    self,
    input: ESMProteinTensor,
    config: GenerationConfig,
) -> ESMProteinTensor | ESMProteinError:
    """Post an ESMProteinTensor prompt to the ``generate_tensor`` endpoint.

    Every track is converted with ``maybe_list`` for the request. The
    response's ``outputs`` entry is converted back with ``maybe_tensor``;
    fields absent from the response become None. A RuntimeError raised by
    the request is returned as an ESMProteinError.
    """
    inputs_payload = {
        "sequence": maybe_list(input.sequence),
        "structure": maybe_list(input.structure),
        "secondary_structure": maybe_list(input.secondary_structure),
        "sasa": maybe_list(input.sasa),
        "function": maybe_list(input.function),
        "coordinates": maybe_list(input.coordinates, convert_nan_to_none=True),
        # The wire name is singular; the tensor attribute is plural.
        "residue_annotation": maybe_list(input.residue_annotations),
    }

    request = dict(
        model=self.model,
        inputs=inputs_payload,
        track=config.track,
        invalid_ids=config.invalid_ids,
        schedule=config.schedule,
        num_steps=config.num_steps,
        temperature=config.temperature,
        top_p=config.top_p,
        condition_on_coordinates_only=config.condition_on_coordinates_only,
    )

    try:
        data = self.__post("generate_tensor", request)
    except RuntimeError as e:
        return ESMProteinError(error_msg=str(e))

    outputs = data["outputs"]

    def _tensor_or_none(field, convert_none_to_nan: bool = False):
        if field in outputs:
            return maybe_tensor(
                outputs[field], convert_none_to_nan=convert_none_to_nan
            )
        return None

    return ESMProteinTensor(
        sequence=_tensor_or_none("sequence"),
        structure=_tensor_or_none("structure"),
        secondary_structure=_tensor_or_none("secondary_structure"),
        sasa=_tensor_or_none("sasa"),
        function=_tensor_or_none("function"),
        residue_annotations=_tensor_or_none("residue_annotation"),
        coordinates=_tensor_or_none("coordinates", convert_none_to_nan=True),
    )
def forward_and_sample(
    self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput:
    """Post a tensor prompt to the ``forward_and_sample`` endpoint.

    Sends every input track plus a per-track sampling config (None for
    tracks with no config), then assembles the sampled tokens and the
    per-track statistics from the response. Probabilities are derived
    locally by exponentiating the returned log-probabilities.
    """
    track_names = ("sequence", "structure", "secondary_structure", "sasa", "function")
    embedding_config = None  # TODO(zeming)

    req = {
        "sequence": maybe_list(input.sequence),
        "structure": maybe_list(input.structure),
        "secondary_structure": maybe_list(input.secondary_structure),
        "sasa": maybe_list(input.sasa),
        "function": maybe_list(input.function),
        "coordinates": maybe_list(input.coordinates, convert_nan_to_none=True),
        "residue_annotation": maybe_list(input.residue_annotations),
    }

    sampling_config = {}
    for name in track_names:
        track_cfg: SamplingTrackConfig | None = getattr(
            sampling_configuration, name, None
        )
        if track_cfg is None:
            sampling_config[name] = None
            continue
        sampling_config[name] = {
            "temperature": track_cfg.temperature,
            "top_p": track_cfg.top_p,
            "only_sample_masked_tokens": track_cfg.only_sample_masked_tokens,
            "invalid_ids": track_cfg.invalid_ids,
            "topk_logprobs": track_cfg.topk_logprobs,
        }

    request = {
        "model": self.model,
        "inputs": req,
        "sampling_config": sampling_config,
        "embedding_config": embedding_config,
    }
    data = self.__post("forward_and_sample", request)

    def get(k, field):
        # A track may be null in the response, and so may a single field.
        entry = data[k]
        if entry is None:
            return None
        v = entry[field]
        return None if v is None else torch.tensor(v)

    def per_track(field):
        return {name: get(name, field) for name in track_names}

    def get_track(field):
        return ForwardTrackData(**per_track(field))

    def operate_on_track(track: ForwardTrackData, fn):
        transformed = {}
        for name in track_names:
            value = getattr(track, name)
            transformed[name] = fn(value) if value is not None else None
        return ForwardTrackData(**transformed)

    # Only the five sampled tracks are read back into the token container.
    tokens = ESMProteinTensor(**per_track("tokens"))

    logprob = get_track("logprobs")
    prob = operate_on_track(logprob, torch.exp)
    entropy = get_track("entropy")
    topk_logprob = get_track("topk_logprobs")
    topk_tokens = get_track("topk_tokens")

    return ForwardAndSampleOutput(
        protein_tensor=tokens,
        logprob=logprob,
        prob=prob,
        entropy=entropy,
        topk_logprob=topk_logprob,
        topk_tokens=topk_tokens,
    )
def encode(self, input: ESMProtein) -> ESMProteinTensor:
    """Tokenize a protein by calling the remote ``encode`` endpoint.

    Args:
        input: Protein to encode. Its function annotations are only sent
            when they are not ``None``.

    Returns:
        The tensor form of every track found under ``outputs`` in the
        response. Coordinates are converted with ``convert_none_to_nan``.

    Raises:
        RuntimeError: Propagated from the POST helper on a failed request.
    """
    inputs = {
        "sequence": input.sequence,
        "secondary_structure": input.secondary_structure,
        "sasa": input.sasa,
    }
    if input.function_annotations is not None:
        inputs["function"] = [
            annotation.to_tuple() for annotation in input.function_annotations
        ]
    inputs["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

    response = self.__post("encode", {"inputs": inputs, "model": self.model})

    sequence = maybe_tensor(response["outputs"]["sequence"])
    structure = maybe_tensor(response["outputs"]["structure"])
    coordinates = maybe_tensor(
        response["outputs"]["coordinates"], convert_none_to_nan=True
    )
    secondary_structure = maybe_tensor(response["outputs"]["secondary_structure"])
    sasa = maybe_tensor(response["outputs"]["sasa"])
    function = maybe_tensor(response["outputs"]["function"])
    # The wire format uses the singular key "residue_annotation".
    residue_annotations = maybe_tensor(response["outputs"]["residue_annotation"])

    return ESMProteinTensor(
        sequence=sequence,
        structure=structure,
        coordinates=coordinates,
        secondary_structure=secondary_structure,
        sasa=sasa,
        function=function,
        residue_annotations=residue_annotations,
    )
def decode(
    self,
    input: ESMProteinTensor,
) -> ESMProtein:
    """Convert a tokenized protein back to an ``ESMProtein`` remotely.

    Args:
        input: Tokenized protein; every track is serialized with
            ``maybe_list`` before being sent.

    Returns:
        An ``ESMProtein`` built from the ``outputs`` of the ``decode``
        endpoint. Function annotations go through
        ``_list_to_function_annotations``; coordinates, pLDDT and pTM are
        converted to tensors.

    Raises:
        RuntimeError: Propagated from the POST helper on a failed request.
    """
    inputs = {
        "sequence": maybe_list(input.sequence),
        "structure": maybe_list(input.structure),
        "secondary_structure": maybe_list(input.secondary_structure),
        "sasa": maybe_list(input.sasa),
        "function": maybe_list(input.function),
        # The wire format uses the singular key "residue_annotation".
        "residue_annotation": maybe_list(input.residue_annotations),
        "coordinates": maybe_list(input.coordinates, convert_nan_to_none=True),
    }
    response = self.__post("decode", {"model": self.model, "inputs": inputs})

    sequence = response["outputs"]["sequence"]
    secondary_structure = response["outputs"]["secondary_structure"]
    sasa = response["outputs"]["sasa"]
    function_annotations = _list_to_function_annotations(
        response["outputs"]["function"]
    )
    coordinates = maybe_tensor(
        response["outputs"]["coordinates"], convert_none_to_nan=True
    )
    plddt = maybe_tensor(response["outputs"]["plddt"])
    ptm = maybe_tensor(response["outputs"]["ptm"])

    return ESMProtein(
        sequence=sequence,
        secondary_structure=secondary_structure,
        sasa=sasa,
        function_annotations=function_annotations,
        coordinates=coordinates,
        plddt=plddt,
        ptm=ptm,
    )
def __post(self, endpoint, request):
    """POST ``request`` as JSON to ``endpoint`` and return the decoded body.

    Args:
        endpoint: Name appended to ``{self.url}/api/v1/``.
        request: JSON-serializable payload.

    Returns:
        The decoded JSON response. When it has no ``"outputs"`` key but does
        have a ``"data"`` key, the value under ``"data"`` is returned instead.

    Raises:
        RuntimeError: If the HTTP response is not OK; the message carries
            the endpoint name and the response text.
    """
    target = f"{self.url}/api/v1/{endpoint}"
    response = requests.post(target, json=request, headers=self.headers)
    if not response.ok:
        raise RuntimeError(f"Failure in {endpoint}: {response.text}")

    payload = response.json()
    # Next.js nests the outputs dict under "data"; unwrap it so callers can
    # always read payload["outputs"] directly.
    wrapped = "outputs" not in payload and "data" in payload
    return payload["data"] if wrapped else payload

View File

@@ -207,7 +207,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
interpro_ids = []
keywords = []
for label in labels:
match = re.match(r"IPR\d+", label)
match = re.search(r"IPR\d+", label)
if match and match.group() in self.interpro_to_index:
interpro_ids.append(match.group())
elif label in self._tfidf.vocab_to_index:

View File

@@ -1,10 +1,10 @@
from functools import cached_property
from pathlib import Path
from typing import Any
import pandas as pd
import torch
import torch.nn.functional as F
from cloudpathlib import AnyPath
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import esm3 as C
@@ -25,13 +25,13 @@ class ResidueAnnotationsTokenizer(EsmTokenizerBase):
@cached_property
def _description2label(self) -> dict[str, str]:
with Path(self.csv_path).open() as f: # type: ignore
with AnyPath(self.csv_path).open() as f: # type: ignore
df = pd.read_csv(f)
return dict(zip(df.label, df.label_clean))
@cached_property
def _labels(self) -> list[str]:
with Path(self.csv_path).open() as f: # type: ignore
with AnyPath(self.csv_path).open() as f: # type: ignore
df = pd.read_csv(f)
labels = (
df.groupby("label_clean")["count"]

View File

@@ -1,3 +1,4 @@
import os
from functools import cache
from pathlib import Path
@@ -106,6 +107,8 @@ def data_root():
]:
if (p := Path(path)).exists():
return p.parent
if "INFRA_PROVIDER" in os.environ:
return Path("")
# Try to download from Hugging Face if it doesn't exist
path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
return path

View File

@@ -189,10 +189,13 @@ def decode_sasa(
sasa_tokens: torch.Tensor,
sasa_tokenizer: SASADiscretizingTokenizer,
) -> list[float]:
_bos_eos_warn("SASA", sasa_tokens, sasa_tokenizer)
if sasa_tokens[0] != 0:
raise ValueError("SASA does not start with 0 corresponding to BOS token")
if sasa_tokens[-1] != 0:
raise ValueError("SASA does not end with 0 corresponding to EOS token")
sasa_tokens = sasa_tokens[1:-1]
return sasa_tokenizer.decode_float(sasa_tokens)
sasa = sasa_tokens.tolist()
return sasa
def decode_function_annotations(

View File

@@ -39,7 +39,7 @@ def encode_function_annotations(
supported_label = False
# Is it an InterPro label?
if match := re.match(r"IPR\d+", fa.label):
if match := re.search(r"IPR\d+", fa.label):
if match.group() in function_tokens_tokenizer.interpro_to_index:
ft_annotations.append(fa)
supported_label = True

View File

@@ -5,11 +5,11 @@ import re
from dataclasses import dataclass
from enum import IntEnum, auto
from functools import cached_property
from pathlib import Path
import networkx as nx
import numpy as np
import pandas as pd
from cloudpathlib import AnyPath
from esm.utils.constants import esm3 as C
@@ -36,7 +36,7 @@ def _parse_interpro2go(path: str) -> dict[str, list[str]]:
Returns:
Mapping from InterPro to list of associated GO terms.
"""
with Path(path).open("r") as f:
with AnyPath(path).open("r") as f:
text = f.read()
df = pd.Series(text.split("\n"), name="line").to_frame()
df = df[~df.line.str.startswith("!")]
@@ -131,7 +131,7 @@ class InterPro:
- "type": InterProEntryType representing the type of annotation.
- "name": Short name of the entry.
"""
with Path(self.entries_path).open("r") as f:
with AnyPath(self.entries_path).open("r") as f:
df = pd.read_csv(f, sep="\t")
assert all(
col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
@@ -178,7 +178,7 @@ class InterPro:
def graph(self) -> nx.DiGraph:
"""Reads the InterPro hierarchy of InterPro."""
graph = nx.DiGraph()
with Path(self.hierarchy_graph_path).open("r") as f:
with AnyPath(self.hierarchy_graph_path).open("r") as f:
parents = []
for line in f:
ipr = line.split("::", maxsplit=1)[0]

View File

@@ -1,6 +1,5 @@
from pathlib import Path
import numpy as np
from cloudpathlib import AnyPath
from esm.utils.types import PathLike
@@ -42,7 +41,7 @@ class LSHTokenized:
):
table_hyperplanes = None
if filepath is not None:
filepath = Path(filepath)
filepath = AnyPath(filepath)
if not filepath.exists():
raise FileNotFoundError(filepath)
table_hyperplanes = np.load(filepath) # type: ignore
@@ -83,7 +82,7 @@ class LSHBitstream:
):
table_hyperplanes = None
if filepath is not None:
filepath = Path(filepath)
filepath = AnyPath(filepath)
if not filepath.exists():
raise FileNotFoundError(filepath)
table_hyperplanes = np.load(filepath)

View File

@@ -4,6 +4,7 @@ from collections import Counter
from functools import cached_property
import numpy as np
from cloudpathlib import AnyPath
from scipy import sparse
@@ -13,10 +14,10 @@ class TFIDFModel:
"""
def __init__(self, vocabulary_path: str, idf_path: str):
with open(vocabulary_path, "r") as f:
with AnyPath(vocabulary_path).open("r") as f:
self.vocabulary = f.read().strip().split("\n")
with open(idf_path, "rb") as f:
with AnyPath(idf_path).open("rb") as f:
self.idf_ = np.load(f)
assert self.idf_.ndim == 1

View File

@@ -1,4 +1,6 @@
from typing import Callable
import os
from typing import Any, Callable, Sequence
from warnings import warn
import attr
import torch
@@ -7,8 +9,14 @@ from tqdm import tqdm
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
ForwardAndSampleOutput,
ForwardConfig,
ForwardOutput,
ForwardTrackData,
GenerationConfig,
ReturnLogitsConfig,
SamplingConfig,
SamplingTrackConfig,
)
@@ -16,170 +24,694 @@ from esm.tokenization import (
EsmTokenizerBase,
TokenizerCollectionProtocol,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
from esm.utils.sampling import (
_BatchedESMProteinTensor,
sample_function_logits,
sample_logits,
sample_residue_annotation_logits,
sample_sasa_logits,
)
from evolutionaryscale.utils.tensor import (
stack_variable_length_tensors,
)
def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
    """Return a copy of an attrs object with tensors cut to ``sequence_len``.

    Every tensor attribute is sliced as ``v[:, :sequence_len]``, so tensors
    are expected to carry a leading batch dimension followed by the sequence
    dimension. Nested attrs objects are trimmed recursively; ``None`` and any
    other value are carried over unchanged.

    Args:
        o: Instance of an attrs-decorated class.
        sequence_len: Number of positions to keep along dimension 1.

    Returns:
        A new object of the same class produced with ``attr.evolve``.
    """
    assert attr.has(o.__class__)

    def _trim(value):
        if isinstance(value, torch.Tensor):
            # Drop padding beyond the requested sequence length.
            return value[:, :sequence_len]
        if value is not None and attr.has(value.__class__):
            return _trim_sequence_tensor_dataclass(value, sequence_len)
        # None and non-tensor payloads pass through untouched.
        return value

    fields = attr.asdict(o, recurse=False)
    return attr.evolve(o, **{name: _trim(value) for name, value in fields.items()})
def _slice_tensor_dataclass(o: Any, i: int, keep_dim: bool = False) -> Any:
    """Pick row ``i`` along the batch dimension of an attrs tensor object.

    Tensor attributes are indexed with ``select(0, i)``. Nested attrs objects
    are sliced recursively; ``None`` and any other value are copied as-is.

    Args:
        o: Instance of an attrs-decorated class to slice.
        i: Index of the row to take.
        keep_dim: When True, re-insert a leading dimension of size 1, so a
            ``(5, 8)`` tensor yields ``(1, 8)`` instead of ``(8,)``.

    Returns:
        A new object of the same class produced with ``attr.evolve``.
    """
    assert attr.has(o.__class__)

    def _pick(value):
        if isinstance(value, torch.Tensor):
            row = value.select(0, i)
            return row.unsqueeze(0) if keep_dim else row
        if value is not None and attr.has(value.__class__):
            return _slice_tensor_dataclass(value, i, keep_dim)
        # None and non-tensor payloads pass through untouched.
        return value

    fields = attr.asdict(o, recurse=False)
    return attr.evolve(o, **{name: _pick(value) for name, value in fields.items()})
def iterative_sampling_raw(
client: ESM3InferenceClient,
input: ESMProtein,
config: GenerationConfig,
):
proteins: list[ESMProtein],
configs: list[GenerationConfig],
) -> list[ESMProtein | ESMProteinError]:
# Keep structure tokens
input_tokens = client.encode(input)
input_tokens = [client.encode(protein) for protein in proteins]
output_tokens = client.generate(input_tokens, config)
output_tokens_list = client.batch_generate(input_tokens, configs)
raw_protein = client.decode(output_tokens)
raw_proteins: list[ESMProtein | ESMProteinError] = []
for output_tokens in output_tokens_list:
if isinstance(output_tokens, ESMProteinTensor):
raw_proteins.append(client.decode(output_tokens))
elif isinstance(output_tokens, ESMProteinError):
raw_proteins.append(output_tokens)
else:
raise ValueError(f"Unknown output type {type(output_tokens)}")
for input_protein, raw_protein, config in zip(proteins, raw_proteins, configs):
if isinstance(raw_protein, ESMProteinError):
# If this generation errored out.
continue
if config.track not in ["function", "residue_annotations"]:
# Function and residue annotation encoding/decoding is lossy
# There is no guarantee that decoding encoded tokens will yield the same input
raw_protein.function_annotations = input_protein.function_annotations
return raw_proteins
def _make_masked_inputs(
    track: str,
    sequence_length: int,
    tokenizers: TokenizerCollectionProtocol,
):
    """Build an all-masked placeholder tensor for ``track``.

    Args:
        track: Track name. ``"coordinates"`` and ``"attention_mask"`` are
            handled specially; any other name is looked up on ``tokenizers``.
        sequence_length: Size of the first dimension.
        tokenizers: Collection providing a tokenizer attribute per track.

    Returns:
        * ``"coordinates"``: float tensor ``(L, 3, 3)`` filled with ``inf``.
        * ``"attention_mask"``: bool tensor ``(L,)`` filled with True.
        * otherwise: long tensor filled with the track's mask token id, with
          the first row set to the BOS id and the last row to the EOS id.
          Shape is ``(L, depth)`` for ``"function"``,
          ``(L, C.MAX_RESIDUE_ANNOTATIONS)`` for ``"residue_annotations"``
          and ``(L,)`` for everything else.
    """
    # Tracks that have no tokenizer are filled directly.
    if track == "coordinates":
        return torch.full((sequence_length, 3, 3), torch.inf, dtype=torch.float)
    if track == "attention_mask":
        return torch.full((sequence_length,), 1, dtype=torch.bool)

    if track == "function":
        shape = (sequence_length, tokenizers.function.depth)
    elif track == "residue_annotations":
        shape = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
    else:
        shape = (sequence_length,)

    tokenizer = getattr(tokenizers, track)
    placeholder = torch.full(shape, tokenizer.mask_token_id, dtype=torch.long)
    # Frame the masked span with BOS / EOS (EOS wins when L == 1).
    placeholder[0] = tokenizer.bos_token_id
    placeholder[-1] = tokenizer.eos_token_id
    return placeholder
def _stack_protein_tensors(
    input_tokens: list[ESMProteinTensor],
    sequence_lengths: list[int],
    tokenizers: TokenizerCollectionProtocol,
    device: str | torch.device,
) -> _BatchedESMProteinTensor:
    """Stack per-protein tensors into one batched protein tensor.

    For every field of ``ESMProteinTensor``, a protein whose field is
    ``None`` gets an all-masked placeholder of its own length (moved to
    ``device``). The per-protein tensors are then combined with
    ``stack_variable_length_tensors``.

    Args:
        input_tokens: Proteins to batch.
        sequence_lengths: Length of each protein, aligned with
            ``input_tokens``.
        tokenizers: Supplies the pad token id of each tokenized track.
        device: Device for the generated placeholders.

    Returns:
        A ``_BatchedESMProteinTensor`` with every field populated.
    """
    batched = _BatchedESMProteinTensor()

    for field in attr.fields(ESMProteinTensor):
        name = field.name
        raw_tensors = [getattr(tokens, name) for tokens in input_tokens]

        # Replace missing tracks with all-mask placeholders.
        filled = []
        for tensor, length in zip(raw_tensors, sequence_lengths):
            if tensor is None:
                tensor = _make_masked_inputs(name, length, tokenizers).to(device)
            filled.append(tensor)

        # Value used by the stacking helper to fill shorter rows: inf for
        # coordinates, the track's pad token id otherwise.
        if name == "coordinates":
            fill_value = torch.inf
        else:
            fill_value = getattr(tokenizers, name).pad_token_id

        stacked = stack_variable_length_tensors(
            sequences=filled,
            constant_value=fill_value,
        )
        setattr(batched, name, stacked)

    return batched
def _get_masked_positions(
    track: str, tokens: torch.Tensor, mask_token_id: int
) -> torch.Tensor:
    """Return a boolean tensor marking where ``tokens`` is masked.

    For the ``"function"`` track a position counts as masked only when every
    entry along the last dimension equals ``mask_token_id``, so the result
    has one dimension fewer than ``tokens``. For any other track the result
    is the element-wise comparison.
    """
    matches = tokens == mask_token_id
    if track != "function":
        return matches
    return matches.all(dim=-1).to(tokens.device)
def _get_iterative_sampling_mask_for_prompt_and_step(
cur_sampled: _BatchedESMProteinTensor,
sequence_lengths: torch.Tensor,
total_to_sample: torch.Tensor,
step: int,
entropy: ForwardTrackData,
config: GenerationConfig,
tokenizers: TokenizerCollectionProtocol,
) -> torch.Tensor:
"""Get sampling mask based on forward output and config.
Returns:
Sampling mask and num of positions sampled.
"""
track_to_sample = config.track
tokens = getattr(cur_sampled, track_to_sample)
device = tokens.device
if track_to_sample not in ["function", "residue_annotations"]:
# Function and residue annotation encoding/decoding is lossy
# There is no guarantee that decoding encoded tokens will yield the same input
raw_protein.function_annotations = input.function_annotations
shape = tokens.shape
B, L = shape[0], shape[1]
return raw_protein
sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device)
sampling_mask[:, 0] = False # BOS
# EOS and all padding tokens.
sampling_mask &= (
torch.arange(L).repeat(B, 1) < (sequence_lengths - 1).unsqueeze(-1)
).to(device)
is_mask = _get_masked_positions(
track_to_sample,
tokens,
getattr(tokenizers, track_to_sample).mask_token_id,
)
if not is_mask.any().item():
raise ValueError(f"Cannot sample {config.track} when input has no masks.")
sampling_mask = sampling_mask & is_mask
# Initialize schedule and masks
decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]
# Calculate number of tokens to sample
still_masked = torch.sum(sampling_mask).int()
perc_masked_after_this_step = decoding_schedule(
torch.tensor((step + 1) / config.num_steps)
)
num_tokens_masked_after_this_step = (
perc_masked_after_this_step * total_to_sample
).int()
num_to_sample = still_masked - num_tokens_masked_after_this_step
track_entropy: torch.Tensor = getattr(
entropy, track_to_sample
) # (B, L) or (B, L, D)
if track_to_sample == "function":
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
track_entropy = track_entropy.masked_fill(
~sampling_mask, torch.finfo(track_entropy.dtype).max
)
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
1, indices, True
)
where_to_sample = sampling_mask & is_top_k
if track_to_sample == "function":
where_to_sample = where_to_sample.unsqueeze(-1).expand(
B,
L,
tokenizers.function.depth,
) # (B, L) -> (B, L, D)
return where_to_sample
def iterative_sampling_tokens(
client: ESM3InferenceClient,
input_tokens: ESMProteinTensor,
config: GenerationConfig,
input_tokens: list[ESMProteinTensor],
configs: list[GenerationConfig],
tokenizers: TokenizerCollectionProtocol,
) -> ESMProteinTensor:
track_to_sample = config.track
) -> Sequence[ESMProteinTensor | ESMProteinError]:
devices = set([t.device for t in input_tokens])
if len(devices) > 1:
raise AttributeError(f"Input tokens on multiple devices {devices}")
# Get all tracks that require sampling
all_tracks = [
f.name for f in attr.fields(SamplingConfig) if "embedding" not in f.name
]
sampled_tokens = [attr.evolve(tokens) for tokens in input_tokens]
sequence_length = len(input_tokens)
device = input_tokens.device
# Clear structure tokens if user would like to condition only on coordinates.
for tokens, config in zip(sampled_tokens, configs):
if config.condition_on_coordinates_only and tokens.coordinates is not None:
tokens.structure = None
# Initialize schedule and masks
decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]
sampled_tokens = attr.evolve(input_tokens) # Make a copy
# Total sequence lengths.
sequence_lengths = [len(tokens) for tokens in sampled_tokens]
# Figure out the number of tokens to be sampled for each prompt.
total_to_sample = []
for protein, seq_len, config in zip(input_tokens, sequence_lengths, configs):
track = config.track
if config.condition_on_coordinates_only and input_tokens.coordinates is not None:
sampled_tokens.structure = None
if getattr(protein, track) is None:
# We need to sample the entire track.
total_to_sample.append(seq_len - 2)
continue
sampling_mask = torch.ones(
sequence_length,
dtype=torch.bool,
device=device,
masked = _get_masked_positions(
track,
getattr(protein, track),
getattr(tokenizers, track).mask_token_id,
)
total_to_sample.append(torch.sum(masked))
# Different prompts may ask for different number of decoding steps.
# For now, we simply run the max number of steps.
# TODO: return completed proteins as soon as they are finished sampling.
max_num_steps = max([config.num_steps for config in configs])
# Now stack the list to make a single batched ESMProteinTensor.
batched_tokens = _stack_protein_tensors(
input_tokens,
sequence_lengths,
tokenizers,
devices.pop(),
)
sampling_mask[0] = False
sampling_mask[-1] = False
get_tokenizer: Callable[[str], EsmTokenizerBase] = lambda s: getattr(tokenizers, s)
if getattr(sampled_tokens, track_to_sample) is None:
if track_to_sample == "function":
dims = (sequence_length, tokenizers.function.depth)
elif track_to_sample == "residue_annotations":
dims = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
else:
dims = (sequence_length,)
masked_tokens = torch.full(
dims,
get_tokenizer(track_to_sample).mask_token_id,
dtype=torch.long,
device=device,
)
if track_to_sample == "sequence":
masked_tokens[0] = tokenizers.sequence.cls_token_id # type: ignore
masked_tokens[-1] = tokenizers.sequence.eos_token_id # type: ignore
else:
masked_tokens[0] = get_tokenizer(track_to_sample).bos_token_id
masked_tokens[-1] = get_tokenizer(track_to_sample).eos_token_id
setattr(
sampled_tokens,
track_to_sample,
masked_tokens,
)
else:
is_mask: torch.Tensor = (
getattr(input_tokens, track_to_sample)
== get_tokenizer(track_to_sample).mask_token_id
)
if not is_mask.any().item():
raise ValueError(f"Cannot sample {config.track} when input has no masks.")
sampling_mask = sampling_mask & is_mask
# Remember sampled prompts that have somehow errored out.
errors: dict[int, ESMProteinError] = {}
# Decode
disable_tqdm = bool(os.environ.get("DISABLE_ITERATIVE_SAMPLING_TQDM", False))
for t in tqdm(range(max_num_steps), disable=disable_tqdm):
forward_out = _batch_forward(client, batched_tokens)
# Sample each prompt individually, since their configuration may
# be very different.
# TODO: downstream utils work with batch dimension.
# Group by sampling configurations and sample those prompts together.
for i, config in enumerate(configs): # B
if i in errors:
# This prompt has errored out in previous steps.
# Skip.
continue
if config.track in ["coordinates", "residue_annotations"]:
errors[i] = ESMProteinError(
error_msg=f"Iterative sampling {config.track} is not supported."
)
continue
if t >= config.num_steps:
# Done sampling for this row.
continue
per_prompt_cur_sampled = _BatchedESMProteinTensor.from_protein_tensor(
batched_tokens.slice(i)
)
per_prompt_forward_out: ForwardOutput = _slice_tensor_dataclass(
forward_out, i, keep_dim=True
)
# Trim logits to proper sequence length for this prompt.
per_prompt_forward_out.logits = _trim_sequence_tensor_dataclass(
per_prompt_forward_out.logits,
# Note(jungong): we cannot simply use sequence_lengths[i] here,
# what we want is for the sequence length of the logits to match
# that of the prompt, which may or may not be padded, depending on
# whether the padding was done locally with the open source model
# (where per_prompt_cur_sampled is already padded) or by
# BatchedForwardRunner (where per_prompt_cur_sampled is not padded).
len(per_prompt_cur_sampled),
)
track_sample_config = SamplingTrackConfig()
track_sample_config.invalid_ids = config.invalid_ids
track_sample_config.temperature = config.temperature
track_sample_config.top_p = config.top_p
sampling_config = SamplingConfig(**{config.track: track_sample_config}) # type: ignore
# Sampling has to be done per-prompt, since sampling configs
# are likely to be different for different prompts.
per_prompt_forward_and_sample_output = _sample_per_prompt(
per_prompt_cur_sampled,
per_prompt_forward_out,
sampling_config,
tokenizers,
)
# All positions sampled after _sample_per_prompt() above.
# (B, L) & (B, L, D)
per_prompt_new_sampled = per_prompt_forward_and_sample_output.protein_tensor
# Find the positions we should sample this round.
assert per_prompt_forward_and_sample_output.entropy is not None
try:
where_to_sample = _get_iterative_sampling_mask_for_prompt_and_step(
per_prompt_cur_sampled,
torch.tensor(sequence_lengths[i]),
torch.tensor(total_to_sample[i]),
t,
per_prompt_forward_and_sample_output.entropy,
config,
tokenizers,
)
except ValueError as e:
errors[i] = ESMProteinError(error_msg=str(e))
continue
where_to_sample.to(input_tokens[0].device)
old_track_samples = getattr(per_prompt_cur_sampled, config.track)
new_track_samples = getattr(per_prompt_new_sampled, config.track)
# Iterative sampling by picking the tokens sampled this round
# from new_track_samples to old_track_samples.
new_track_samples = torch.where(
where_to_sample, new_track_samples, old_track_samples
)
# Update the corresponding row with new data.
getattr(batched_tokens, config.track)[i, ...] = new_track_samples[0]
# Un-pack to a list of single ProteinTypes.
output_tokens = [
batched_tokens.slice(i, sequence_len=sequence_lengths[i])
if i not in errors
else errors[i]
for i in range(len(input_tokens))
]
# Do not update tracks that were not sampled (e.g. keep None instead of masks)
for inputs, outputs, config in zip(input_tokens, output_tokens, configs):
if isinstance(outputs, ESMProteinError):
continue
# First restore coordinates field.
# We know coordinates can never be iteratively sampled.
setattr(outputs, "coordinates", getattr(inputs, "coordinates"))
# Maybe restore all the other fields.
for f in attr.fields(SamplingConfig):
if "embedding" in f.name:
continue
if f.name != config.track:
setattr(outputs, f.name, getattr(inputs, f.name))
return output_tokens
def _batch_forward(
    client: ESM3InferenceClient,
    protein: _BatchedESMProteinTensor,
):
    """Run ``client._forward`` on a batch, requesting all logits and embeddings.

    Args:
        client: Inference client exposing ``_forward``.
        protein: Batched protein tensor passed to the model.

    Returns:
        Whatever ``client._forward`` returns for this configuration.
    """
    logits_config = ReturnLogitsConfig(
        sequence=True,
        structure=True,
        secondary_structure=True,
        sasa=True,
        function=True,
        residue_annotations=True,
    )
    forward_config = ForwardConfig(logits_config, return_embeddings=True)
    return client._forward(protein, forward_config)
def _sample_per_prompt(
protein: _BatchedESMProteinTensor,
forward_output: ForwardOutput,
sampling_config: SamplingConfig,
tokenizers: TokenizerCollectionProtocol,
) -> ForwardAndSampleOutput:
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
return x.clone() if x is not None else None
L = sequence_length - 2
positions_sampled = 0
for t in tqdm(range(config.num_steps)):
# Single step sampling at all positions
track_sample_config = SamplingTrackConfig()
track_sample_config.invalid_ids = config.invalid_ids
track_sample_config.temperature = config.temperature
track_sample_config.top_p = config.top_p
sampling_config = SamplingConfig(**{track_to_sample: track_sample_config}) # type: ignore
forward_and_sample_output = client.forward_and_sample(
sampled_tokens, sampling_config
)
new_samples = forward_and_sample_output.protein_tensor
# Calculate number of tokens to sample
perc_masked = decoding_schedule(torch.tensor((t + 1) / config.num_steps))
num_to_sample = int((1 - perc_masked) * L) - positions_sampled
positions_sampled += num_to_sample
# Select tokens based on lowest entropy
if track_to_sample in ["function", "residue_annotations"]:
# TODO: Implement iterative decoding for function and residue_annotations
# TODO: Fix encode/decode of interpro tokens (not yet supported)
sampled_tokens.function = maybe_clone(input_tokens.function)
sampled_tokens.residue_annotations = maybe_clone(
input_tokens.residue_annotations
)
if track_to_sample in track_to_sample:
raise NotImplementedError(
f"Iterative decoding for {track_to_sample} is not supported yet."
)
# Sampling
tokens_dir = {}
track_sampling_metadata_dir: dict[str, dict | None] = {}
for track in ["sequence", "structure", "secondary_structure"]:
config = getattr(sampling_config, track)
if config is None:
tokens_dir[track] = maybe_clone(getattr(protein, track))
continue
sampling_mask = sampling_mask & (
getattr(sampled_tokens, track_to_sample)
== get_tokenizer(track_to_sample).mask_token_id
sampling_metadata = _sample_track(
logits=getattr(forward_output.logits, track),
tokens=getattr(protein, track),
sampling_track_config=config,
mask_idx=getattr(tokenizers, track).mask_token_id,
)
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
track_sampling_metadata_dir[track] = sampling_metadata
track_entropy: torch.Tensor = getattr(
forward_and_sample_output.entropy, track_to_sample
# Sample SASA separately
config = getattr(sampling_config, "sasa")
track_sampling_metadata_dir["sasa"] = None
if config is not None:
if config.topk_logprobs > 0:
warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
sasa_logits = forward_output.logits.sasa[0, ...] # type: ignore
sasa_value = sample_sasa_logits(sasa_logits, protein.sasa[0, ...]) # type: ignore
tokens_dir["sasa"] = sasa_value
probs = sasa_logits.softmax(dim=-1)
entropy = -(probs * sasa_logits.log_softmax(-1)).sum(-1)
track_sampling_metadata_dir["sasa"] = {"entropy": entropy}
# Sample function and residue annotations separately
config = getattr(sampling_config, "function")
if config is None:
tokens_dir["function"] = maybe_clone(getattr(protein, "function"))
tokens_dir["residue_annotations"] = maybe_clone(
getattr(protein, "residue_annotations")
)
track_entropy = track_entropy.masked_fill(
~sampling_mask, torch.finfo(track_entropy.dtype).max
else:
sampling_metadata = _sample_function_track(
tokenizers.function,
tokens=getattr(protein, "function"),
logits=getattr(forward_output.logits, "function"),
sampling_track_config=config,
)
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
is_top_k = ~(
torch.arange(sequence_length, device=device)[:, None] != indices[None, :]
).all(-1)
tokens_to_sample = sampling_mask & is_top_k
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
track_sampling_metadata_dir["function"] = sampling_metadata
old_track_samples = getattr(sampled_tokens, track_to_sample)
new_track_samples = getattr(new_samples, track_to_sample)
new_track_samples = torch.where(
tokens_to_sample, new_track_samples, old_track_samples
sampled_tokens, _ = sample_residue_annotation_logits(
logits=forward_output.residue_annotation_logits # type: ignore
)
tokens_dir["residue_annotations"] = sampled_tokens # (L, MAX_R)
setattr(sampled_tokens, track_to_sample, new_track_samples)
# Do not update tracks that were not sampled (e.g. keep None instead of masks)
for track in all_tracks:
if track != track_to_sample:
setattr(
sampled_tokens,
track,
maybe_clone(getattr(input_tokens, track)),
# Format output
forward_and_sample_output_dir = {}
forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
for property in [
"entropy",
"prob",
"logprob",
"top_prob",
"topk_logprob",
"topk_tokens",
]:
is_all_none = True
forward_track_data_dir = {}
for track in track_sampling_metadata_dir.keys():
values = track_sampling_metadata_dir[track]
if values is not None and values.get(property, None) is not None:
forward_track_data_dir[track] = values.get(property, None)
is_all_none = False
if not is_all_none:
forward_and_sample_output_dir[property] = ForwardTrackData(
**forward_track_data_dir
)
else:
forward_and_sample_output_dir[property] = None
return sampled_tokens
per_res_embed = (
forward_output.embeddings # type: ignore
if sampling_config.return_per_residue_embeddings
else None
)
mean_embedding = (
# [B, L, D] -> [B, D]
forward_output.embeddings[0].mean(dim=1) # type: ignore
if sampling_config.return_mean_embedding
else None
)
return ForwardAndSampleOutput(
per_residue_embedding=per_res_embed,
mean_embedding=mean_embedding,
**forward_and_sample_output_dir,
)
def _sample_track(
    logits: torch.Tensor,
    tokens: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
    mask_idx: int,
) -> dict[str, torch.Tensor]:
    """Sample one track from ``logits`` and compute its sampling metadata.

    Works with inputs that have a batch dimension. Candidates are drawn at
    every position and then kept only where sampling is allowed: never at the
    first or last position, never where the current token is one of
    ``invalid_ids`` (the mask token excepted), and, when
    ``only_sample_masked_tokens`` is set, only where the current token is the
    mask token. Everywhere else the original token is retained.

    Args:
        logits: Logits for the track.
        tokens: Current tokens of the track.
        sampling_track_config: Temperature, top-p, invalid ids, masking
            policy and top-k setting.
        mask_idx: Id of the mask token.

    Returns:
        The dict produced by ``_compute_track_metadata`` for the final
        tokens, the log-softmax of ``logits`` and the sampling mask.
    """
    candidates = sample_logits(
        logits,
        temperature=sampling_track_config.temperature,
        top_p=sampling_track_config.top_p,
    )
    log_probs = logits.log_softmax(-1)

    # BOS and EOS are never resampled.
    may_sample = torch.ones_like(tokens, dtype=torch.bool)  # (B, L)
    may_sample[:, 0] = False
    may_sample[:, -1] = False

    # Positions holding a special token stay fixed, but the mask token itself
    # must remain sampleable.
    protected_ids = list(set(sampling_track_config.invalid_ids) - {mask_idx})
    if protected_ids:
        protected = torch.tensor(protected_ids, device=tokens.device)
        holds_protected = (tokens[..., None] == protected[None, :]).any(-1)
        may_sample = may_sample & ~holds_protected

    if sampling_track_config.only_sample_masked_tokens:
        may_sample = may_sample & (tokens == mask_idx)

    final_tokens = torch.where(may_sample, candidates, tokens)

    return _compute_track_metadata(
        final_tokens,
        log_probs,
        may_sample,
        top_k=sampling_track_config.topk_logprobs,
    )
def _sample_function_track(
    function_tokenizer: InterProQuantizedTokenizer,
    tokens: torch.Tensor,
    logits: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
) -> dict[str, torch.Tensor]:
    """Sample the function track (batched) and return per-position metadata.

    Candidates and their log-probabilities come from
    ``sample_function_logits``. A position is sampled as a whole: the mask is
    kept per (batch, position) and broadcast over the depth axis. The first
    and last positions are never sampled. When ``only_sample_masked_tokens``
    is set, a position is sampled only if every depth entry of its input
    token equals ``function_tokenizer.mask_token_id``.

    At positions that are not sampled, the log-probabilities are replaced by
    a distribution with log-probability 0 on the input token and ``-inf``
    elsewhere.

    Args:
        function_tokenizer: Tokenizer passed through to
            ``sample_function_logits``; also provides ``mask_token_id``.
        tokens: Input function tokens, (B, L, D).
        logits: Function logits, (B, L, D, V).
        sampling_track_config: Supplies ``top_p``, ``temperature``,
            ``only_sample_masked_tokens`` and ``topk_logprobs``.

    Returns:
        The dictionary from ``_compute_track_metadata`` with ``"entropy"``
        summed over the depth axis, i.e. (B, L, D) -> (B, L).
    """
    cfg = sampling_track_config

    # One flag per (batch, position): drop the depth axis of `tokens`.
    may_sample = torch.ones_like(tokens, dtype=torch.bool)[..., 0]  # (B, L)
    # BOS and EOS are never resampled.
    may_sample[..., 0] = False
    may_sample[..., -1] = False

    candidates, logprobs = sample_function_logits(
        logits,
        function_tokenizer,
        top_p=cfg.top_p,
        temperature=cfg.temperature,
    )

    if cfg.only_sample_masked_tokens:
        fully_masked = (tokens == function_tokenizer.mask_token_id).all(dim=-1)  # (B, L)
        may_sample = may_sample & fully_masked

    take_candidate = may_sample[..., None].expand_as(candidates)
    sampled_tokens = torch.where(take_candidate, candidates, tokens)  # (B, L, D)

    # Distribution used where nothing was sampled: all mass on the input token.
    zero_column = torch.zeros_like(logprobs)[..., [0]]
    kept_logprobs = torch.full_like(logprobs, -torch.inf).scatter(
        -1, tokens[..., None], zero_column
    )  # (B, L, D, V)
    logprobs = torch.where(
        may_sample[..., None, None].expand_as(logprobs), logprobs, kept_logprobs
    )  # (B, L, D, V)

    metadata = _compute_track_metadata(
        sampled_tokens,
        logprobs,
        may_sample,
        top_k=cfg.topk_logprobs,
    )

    # Report the entropy of the joint distribution over all depth entries of a
    # position by summing the per-entry entropies.
    metadata["entropy"] = metadata["entropy"].sum(-1)  # (B, L, D) -> (B, L)
    return metadata
def _compute_track_metadata(
    sampled_tokens: torch.Tensor,
    log_probs: torch.Tensor,
    sampling_mask: torch.Tensor,
    top_k: int,
) -> dict:
    """Summarise a sampling step: entropy, chosen-token probability and top-k.

    Works with inputs that have a batch dimension. ``log_probs`` has one more
    trailing axis (the vocabulary) than ``sampled_tokens``.

    Args:
        sampled_tokens: Token ids, (B, L) or (B, L, D).
        log_probs: Log-probabilities over the vocabulary, matching
            ``sampled_tokens`` plus a last vocabulary axis.
        sampling_mask: Boolean mask of positions that were sampled. A (B, L)
            mask is broadcast over D when ``sampled_tokens`` is (B, L, D).
        top_k: Number of top entries to report per position; 0 disables the
            top-k outputs.

    Returns:
        A dict with keys ``entropy``, ``sampled_tokens`` (the input tensor),
        ``prob`` and ``logprob`` (of the chosen token; log-probability 0 and
        probability 1 outside the mask), ``top_prob`` (largest probability
        per position), and ``topk_logprob`` / ``topk_tokens`` (``None`` when
        ``top_k`` is 0).
    """
    entropy = torch.distributions.Categorical(logits=log_probs).entropy()

    # Bring the mask to the shape of the tokens when they carry a depth axis.
    mask = sampling_mask
    if sampled_tokens.dim() > mask.dim():
        assert sampled_tokens.dim() == 3  # (B, L, D)
        assert mask.dim() == 2  # (B, L)
        mask = mask[..., None].expand_as(sampled_tokens)

    # Look up the log-probability of the chosen token at sampled positions
    # only; every other position keeps log-probability 0.
    chosen_ids = sampled_tokens[mask]
    chosen_logprob = torch.zeros_like(sampled_tokens, dtype=log_probs.dtype)
    chosen_logprob[mask] = log_probs[mask, chosen_ids]
    chosen_prob = chosen_logprob.exp()

    best_prob = log_probs.exp().max(dim=-1).values

    topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
    if top_k == 0:
        topk_logprobs, topk_tokens = None, None

    return {
        "entropy": entropy,
        "sampled_tokens": sampled_tokens,
        "prob": chosen_prob,
        "logprob": chosen_logprob,
        "top_prob": best_prob,
        "topk_logprob": topk_logprobs,
        "topk_tokens": topk_tokens,
    }

View File

@@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
from esm.sdk.api import (
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
@@ -13,7 +14,92 @@ from esm.tokenization import (
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS
from esm.utils.constants.esm3 import (
MAX_RESIDUE_ANNOTATIONS,
SASA_DISCRETIZATION_BOUNDARIES,
)
# Number of dimensions for each protein tensor field without the batch dimension.
# _BatchedESMProteinTensor asserts that each field it inspects has exactly one
# more dimension than listed here (the leading batch axis).
_DIMS: dict[str, int] = {
    "sequence": 1,
    "structure": 1,
    "secondary_structure": 1,
    "sasa": 1,
    "function": 2,
    "residue_annotations": 2,
    "coordinates": 3,
}
class _BatchedESMProteinTensor(ESMProteinTensor):
    """An ``ESMProteinTensor`` whose tensor fields carry a leading batch axis.

    Every field that is set is expected to have ``_DIMS[field] + 1``
    dimensions; this is asserted when the length or batch size is queried.
    """

    @staticmethod
    def from_protein_tensor(protein: ESMProteinTensor):
        """Wrap a single protein as a batch of one.

        Each tensor field of ``protein`` gets a new leading axis of size 1;
        fields that are ``None`` stay ``None``.
        """

        def _add_batch_axis(x: torch.Tensor | None):
            if x is None:
                return None
            return x.unsqueeze(0)

        return _BatchedESMProteinTensor(
            sequence=_add_batch_axis(protein.sequence),
            structure=_add_batch_axis(protein.structure),
            secondary_structure=_add_batch_axis(protein.secondary_structure),
            sasa=_add_batch_axis(protein.sasa),
            function=_add_batch_axis(protein.function),
            residue_annotations=_add_batch_axis(protein.residue_annotations),
            coordinates=_add_batch_axis(protein.coordinates),
        )

    def __len__(self) -> int:
        """Return the size of axis 1 (sequence length), or 0 if none is detected."""

        def _length_of(name, tensor) -> int:
            # Batched fields have exactly one axis more than their entry in _DIMS.
            assert tensor.dim() == _DIMS[name] + 1
            return tensor.size(1)

        length = self._detect_attribute(_length_of, "length")
        return 0 if length is None else length

    @property
    def batch_size(self) -> int:
        """Return the size of axis 0; asserts that one could be detected."""

        def _batch_of(name, tensor) -> int:
            # Batched fields have exactly one axis more than their entry in _DIMS.
            assert tensor.dim() == _DIMS[name] + 1
            return tensor.size(0)

        size = self._detect_attribute(_batch_of, "batch size")
        assert size is not None
        return size

    def slice(
        self,
        i: int,
        sequence_len: int | None = None,
    ) -> ESMProteinTensor:
        """Extract batch entry ``i`` as an unbatched ``ESMProteinTensor``.

        Args:
            i: Index along the batch axis.
            sequence_len: If given, each extracted tensor is truncated to its
                first ``sequence_len`` entries along its leading axis.

        Returns:
            A new ``ESMProteinTensor``; fields that are ``None`` here remain
            ``None``.
        """

        def _entry(x: torch.Tensor | None):
            if x is None:
                return None
            return x[i] if sequence_len is None else x[i][:sequence_len]

        return ESMProteinTensor(
            sequence=_entry(self.sequence),
            structure=_entry(self.structure),
            secondary_structure=_entry(self.secondary_structure),
            sasa=_entry(self.sasa),
            function=_entry(self.function),
            residue_annotations=_entry(self.residue_annotations),
            coordinates=_entry(self.coordinates),
        )

    def set_slice(self, i: int, slice: ESMProteinTensor):
        """Write the fields of ``slice`` into batch entry ``i``, in place.

        Fields that are ``None`` on ``slice`` are left untouched. A field
        that is set on ``slice`` must also be set on this batch.
        """
        for f in attr.fields(ESMProteinTensor):
            target = getattr(self, f.name)
            incoming = getattr(slice, f.name)

            if incoming is None:
                continue

            assert (
                target is not None
            ), f"Trying to set a slice on None tensor ({f.name})."
            target[i, ...] = incoming
def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
@@ -79,7 +165,8 @@ def sample_function_logits(
temperature: float | torch.Tensor = 1.0,
p_none_threshold: float = 0.05,
) -> tuple[torch.Tensor, torch.Tensor]:
[L, D, V] = logits.shape
"""Works with inputs that have batch dimension."""
[B, L, D, V] = logits.shape
assert D == tokenizer.depth
if top_p < 1.0:
@@ -87,18 +174,26 @@ def sample_function_logits(
temperature = torch.ones_like(logits[..., 0]) * temperature
log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (L, D, V)
log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (B, L, D, V)
# Choose which positions have no predicted function.
log_p_nones = log_p[..., tokenizer.vocab_to_index["<none>"]] # (L, D)
none_index = tokenizer.vocab_to_index["<none>"]
log_p_nones = log_p[..., none_index] # (B, L, D)
p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of <none> predictions"
where_none = p_none > p_none_threshold # (L, )
where_none = p_none > p_none_threshold # (B, L)
# Set probability of <none> to 0 for all not-none positions
none_index = tokenizer.vocab_to_index["<none>"]
log_p[~where_none, :, none_index] = -torch.inf
batch_size, seq_len, depth = log_p.shape[:-1]
expanded_where_not_none = ~where_none.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1)
expanded_where_not_none = expanded_where_not_none.expand(
batch_size, seq_len, depth, 1
) # (B, L, D, 1)
indices = torch.arange(log_p.shape[-1], device=log_p.device) # (V,)
mask = indices == none_index # (V,)
mask = expanded_where_not_none & mask # (B, L, D, 1) x (V,) -> (B, L, D, V)
log_p[mask] = -torch.inf
ids = torch.argmax(log_p, dim=-1) # (L, D)
ids = torch.argmax(log_p, dim=-1) # (B, L, D)
ids[where_none, :] = tokenizer.vocab_to_index["<none>"]
return ids, log_p
@@ -110,10 +205,10 @@ def sample_residue_annotation_logits(
# Take top residue annotations
top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
..., :MAX_RESIDUE_ANNOTATIONS
] # (L, MAX_R)
] # (B, L, MAX_R)
top_residue_annotations_logprobs = torch.gather(
F.logsigmoid(logits), -1, top_residue_annotations_idx
) # (L, MAX_R)
) # (B, L, MAX_R)
top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
# Keep only positive predictions
is_negative = top_residue_annotations_probs < annotation_threshold
@@ -124,6 +219,26 @@ def sample_residue_annotation_logits(
return top_residue_annotations_idx, top_residue_annotations_logprobs
def sample_sasa_logits(
    logits: torch.Tensor,
    tokens: torch.Tensor,
) -> torch.Tensor:
    """Turn SASA logits into one SASA value per position.

    The value is the sum, over vocabulary indices ``3:-1``, of the softmax
    probability times the midpoint of the corresponding bin; bin edges are
    ``[0] + SASA_DISCRETIZATION_BOUNDARIES``. The softmax is taken over the
    full vocabulary and is not renormalised over that slice.

    Positions whose input token is 0, 1 or 2 are set to ``-inf``. Positions
    whose most likely vocabulary index is 18 are set to ``+inf``; this is
    applied last and therefore wins when both conditions hold.

    NOTE(review): ids 0/1/2 look like special tokens and index 18 like the
    open-ended top bin -- confirm against the SASA tokenizer.

    Args:
        logits: SASA logits; the last axis is the vocabulary.
        tokens: Input SASA tokens, one per position of ``logits``.

    Returns:
        A float tensor shaped like ``logits`` without its last axis.
    """
    probs = logits.softmax(dim=-1)
    most_likely = probs.argmax(dim=-1)

    # Midpoint of every bin delimited by consecutive boundaries.
    edges = torch.tensor([0] + SASA_DISCRETIZATION_BOUNDARIES, dtype=torch.float)
    midpoints = ((edges[:-1] + edges[1:]) / 2).to(probs.device)

    expected = (probs[..., 3:-1] * midpoints).sum(dim=-1)

    for token_id in (0, 1, 2):
        expected[tokens == token_id] = float("-inf")
    expected[most_likely == 18] = float("inf")
    return expected
def top_p_logits(
logits: torch.Tensor,
top_p: float | torch.Tensor,

View File

@@ -9,11 +9,6 @@ from typing_extensions import Self
from esm.utils.misc import fp32_autocast_context
def maybe_compile(func, x: torch.Tensor):
    """Return ``torch.compile(func)`` if ``x`` is on a CUDA device, else ``func``.

    Compilation is skipped for other devices because torch.compile was seen
    to cause issues with CPU tensors.
    """
    if x.device.type != "cuda":
        return func
    return torch.compile(func)
@T.runtime_checkable
class Rotation(T.Protocol):
@classmethod
@@ -154,9 +149,7 @@ class RotationMatrix(Rotation):
x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12
) -> RotationMatrix:
# A low eps here is necessary for good stability!
return RotationMatrix(
maybe_compile(_graham_schmidt, x_axis)(x_axis, xy_plane, eps)
)
return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
@dataclass(frozen=True)

View File

@@ -17,6 +17,7 @@ from biotite.application.dssp import DsspApp
from biotite.database import rcsb
from biotite.structure.io.npz import NpzFile
from biotite.structure.io.pdb import PDBFile
from cloudpathlib import CloudPath
from scipy.spatial.distance import pdist, squareform
from torch import Tensor
@@ -38,7 +39,7 @@ CHAIN_ID_CONST = "A"
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
PathLike = Union[str, Path]
PathLike = Union[str, Path, CloudPath]
PathOrBuffer = Union[PathLike, io.StringIO]

View File

@@ -5,7 +5,9 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Union
PathLike = Union[str, Path]
from cloudpathlib import CloudPath
PathLike = Union[str, Path, CloudPath]
PathOrBuffer = Union[PathLike, io.StringIO]

801
examples/esmprotein.ipynb Normal file
View File

@@ -0,0 +1,801 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Input tracks of `ESMProtein`\n",
"\n",
"ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we will familiarize ourselves with the `ESMProtein` class, which holds multiple properties of a protein representing sequence, structure, and function. The ESM3 models use these properties from the input (prompts) and generate them as part of the output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An `ESMProtein` has 5 attributes that represent input (promptable) tracks:\n",
"\n",
"* `sequence`: amino acid sequence\n",
"* `coordinates`: 3D coordinates of atoms in each amino acid of the protein\n",
"* `secondary_structure`: [8-class secondary structure](https://en.wikipedia.org/wiki/Protein_secondary_structure#DSSP_classification) (SS8)\n",
"* `sasa`: [solvent-accessible surface area](https://en.wikipedia.org/wiki/Accessible_surface_area) (SASA)\n",
"* `function_annotations`: function annotations derived from [InterPro](https://www.ebi.ac.uk/interpro/)\n",
"\n",
"You can prompt an ESM3 model by setting any subset of these tracks to be partially unmasked when calling the model with an `ESMProtein` instance."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to create an `ESMProtein` object is from a pdb id and chain id from [RCSB](https://www.rcsb.org). Below, we first create a `ProteinChain` with the pdb id and chain id and then create an `ESMProtein` from it. This will populate the `sequence` and `coordinates` properties."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install esm\n",
"! pip install esm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.types import FunctionAnnotation\n",
"\n",
"pdb_id = \"1cm4\"\n",
"chain_id = \"A\"\n",
"\n",
"# Create a protein using a pdb format file from RCSB\n",
"# Note: instead of the next two lines, we could use\n",
"# protein_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n",
"# but in future implementations, this function may use the mmcif file\n",
"# which would throw off some indices later on in this notebook\n",
"str_io = rcsb.fetch(pdb_id, \"pdb\")\n",
"protein_chain = ProteinChain.from_pdb(str_io, chain_id=chain_id, id=pdb_id)\n",
"protein = ESMProtein.from_protein_chain(protein_chain)\n",
"\n",
"## We can also load from a local pdb file by passing its path\n",
"# protein_chain = ProteinChain.from_pdb('xxxx.pdb', chain_id=chain_id, id=pdb_id)\n",
"# The chain_id and id arguments are optional and will be inferred if None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `sequence`\n",
"The `sequence` track contains a sequence of 1-letter representation of the amino acids in the protein:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(protein.sequence)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `coordinates`\n",
"\n",
"\n",
"`coordinates` contains the 3D coordinates of atoms in the protein. It contains a tensor of shape `(n_residues, 37, 3)`, where \n",
"\n",
"* `n_residues` is the number of amino acids in the protein.\n",
"* `37` is the maximum possible number of atoms in an amino acid, represented in the atom37 representation. If certain atoms are not present in the structure, they will show up as `nan`.\n",
"* `3` is for 3D (x,y,z) coordinates. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(protein.coordinates.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(protein.coordinates)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define two functions below that visualize the `coordinates` attribute: we define two functions below (as before, there is no need to go through them)\n",
"* `visualize_3D_coordinates()` visualizes directly from the coordinates tensor by creating a pdb file with all alanines\n",
"* `visualize_3D_protein()` visualizes from the `ESMProtein` instance, which has the correct amino acids"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Functions for visualizing 3D structure\n",
"\"\"\"\n",
"\n",
"! pip install py3Dmol\n",
"\n",
"import py3Dmol\n",
"\n",
"def visualize_pdb(pdb_string):\n",
" view = py3Dmol.view(width=400, height=400)\n",
" view.addModel(pdb_string, \"pdb\")\n",
" view.setStyle({'cartoon': {'color': 'spectrum'}})\n",
" view.zoomTo()\n",
" view.render()\n",
" view.center()\n",
" return view\n",
"\n",
"def visualize_3D_coordinates(coordinates):\n",
" \"\"\"\n",
" This uses all Alanines\n",
" \"\"\"\n",
" protein_with_same_coords = ESMProtein(coordinates=coordinates)\n",
" # pdb with all alanines\n",
" pdb_string = protein_with_same_coords.to_pdb_string()\n",
" return visualize_pdb(pdb_string)\n",
"\n",
"def visualize_3D_protein(protein):\n",
" pdb_string = protein.to_pdb_string()\n",
" return visualize_pdb(pdb_string)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# visualize from just the coordinates\n",
"visualize_3D_coordinates(protein.coordinates)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# visualize from sequence and coordinates\n",
"visualize_3D_protein(protein)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `secondary_structure`\n",
"\n",
"The `secondary_structure` property contains a representation of the secondary structure. At a high level of categorization, we can classify each amino acid as belonging into three classes: alpha helices, beta sheets, and coil, which we could see in the previous 3D visualization.\n",
"\n",
"`ESMProtein` uses a [8-class secondary structure](https://en.wikipedia.org/wiki/Protein_secondary_structure#DSSP_classification) that can be computed with [dssp](https://swift.cmbi.umcn.nl/gv/dssp/) given 3D atom coordinates. Since installing dssp is a separate process from installing the `esm` package, in this notebook, we show how to compute the coarser 3-class classification using biotite's [annotate_sse](https://www.biotite-python.org/apidoc/biotite.structure.annotate_sse.html). We can set the `secondary_structure` property with this 3-class classification."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from biotite.structure import annotate_sse\n",
"\n",
"def get_approximate_ss(protein_chain: ProteinChain):\n",
" # get biotite's ss3 representation\n",
" ss3_arr = annotate_sse(protein_chain.atom_array)\n",
" biotite_ss3_str = ''.join(ss3_arr)\n",
"\n",
" # translate into ESM3's representation\n",
" translation_table = str.maketrans({\n",
" 'a': 'H', # alpha helix\n",
" 'b': 'E', # beta sheet\n",
" 'c': 'C', # coil\n",
" })\n",
" esm_ss3 = biotite_ss3_str.translate(translation_table)\n",
" return esm_ss3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"protein.secondary_structure = get_approximate_ss(protein_chain)\n",
"print(protein.secondary_structure)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell defines a function that visualizes the secondary structure and there is no need to read them!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ! pip install matplotlib\n",
"\n",
"# Slightly modified version of secondary structure plotting code from\n",
"# https://www.biotite-python.org/examples/gallery/structure/transketolase_sse.html\n",
"# Code source: Patrick Kunzmann\n",
"# License: BSD 3 clause\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle\n",
"import biotite\n",
"import biotite.sequence as seq\n",
"import biotite.sequence.graphics as graphics\n",
"\n",
"# Create 'FeaturePlotter' subclasses\n",
"# for drawing the secondary structure features\n",
"class HelixPlotter(graphics.FeaturePlotter):\n",
"\n",
" def __init__(self):\n",
" pass\n",
"\n",
" # Check whether this class is applicable for drawing a feature\n",
" def matches(self, feature):\n",
" if feature.key == \"SecStr\":\n",
" if \"sec_str_type\" in feature.qual:\n",
" if feature.qual[\"sec_str_type\"] == \"helix\":\n",
" return True\n",
" return False\n",
"\n",
" # The drawing function itself\n",
" def draw(self, axes, feature, bbox, loc, style_param):\n",
" # Approx. 1 turn per 3.6 residues to resemble natural helix\n",
" n_turns = np.ceil((loc.last - loc.first + 1) / 3.6)\n",
" x_val = np.linspace(0, n_turns * 2*np.pi, 100)\n",
" # Curve ranges from 0.3 to 0.7\n",
" y_val = (-0.4*np.sin(x_val) + 1) / 2\n",
"\n",
" # Transform values for correct location in feature map\n",
" x_val *= bbox.width / (n_turns * 2*np.pi)\n",
" x_val += bbox.x0\n",
" y_val *= bbox.height\n",
" y_val += bbox.y0\n",
"\n",
" # Draw white background to overlay the guiding line\n",
" background = Rectangle(\n",
" bbox.p0, bbox.width, bbox.height, color=\"white\", linewidth=0\n",
" )\n",
" axes.add_patch(background)\n",
" axes.plot(\n",
" x_val, y_val, linewidth=2, color=biotite.colors[\"dimgreen\"]\n",
" )\n",
"\n",
"\n",
"class SheetPlotter(graphics.FeaturePlotter):\n",
"\n",
" def __init__(self, head_width=0.8, tail_width=0.5):\n",
" self._head_width = head_width\n",
" self._tail_width = tail_width\n",
"\n",
" def matches(self, feature):\n",
" if feature.key == \"SecStr\":\n",
" if \"sec_str_type\" in feature.qual:\n",
" if feature.qual[\"sec_str_type\"] == \"sheet\":\n",
" return True\n",
" return False\n",
"\n",
" def draw(self, axes, feature, bbox, loc, style_param):\n",
" x = bbox.x0\n",
" y = bbox.y0 + bbox.height/2\n",
" dx = bbox.width\n",
" dy = 0\n",
"\n",
" if loc.defect & seq.Location.Defect.MISS_RIGHT:\n",
" # If the feature extends into the previous or next line\n",
" # do not draw an arrow head\n",
" draw_head = False\n",
" else:\n",
" draw_head = True\n",
"\n",
" axes.add_patch(biotite.AdaptiveFancyArrow(\n",
" x, y, dx, dy,\n",
" self._tail_width*bbox.height, self._head_width*bbox.height,\n",
" # Create head with 90 degrees tip\n",
" # -> head width/length ratio = 1/2\n",
" head_ratio=0.5, draw_head=draw_head,\n",
" color=biotite.colors[\"orange\"], linewidth=0\n",
" ))\n",
"\n",
"# Converter for the DSSP secondary structure elements\n",
"# to the classical ones\n",
"dssp_to_abc = {\"I\" : \"c\",\n",
" \"S\" : \"c\",\n",
" \"H\" : \"a\",\n",
" \"E\" : \"b\",\n",
" \"G\" : \"c\",\n",
" \"B\" : \"b\",\n",
" \"T\" : \"c\",\n",
" \"C\" : \"c\"}\n",
"\n",
"def visualize_secondary_structure(sse, first_id):\n",
" \"\"\"\n",
" Helper function to convert secondary structure array to annotation\n",
" and visualize it.\n",
" \"\"\"\n",
"\n",
" def _add_sec_str(annotation, first, last, str_type):\n",
" if str_type == \"a\":\n",
" str_type = \"helix\"\n",
" elif str_type == \"b\":\n",
" str_type = \"sheet\"\n",
" else:\n",
" # coil\n",
" return\n",
" feature = seq.Feature(\n",
" \"SecStr\", [seq.Location(first, last)], {\"sec_str_type\" : str_type}\n",
" )\n",
" annotation.add_feature(feature)\n",
"\n",
" # Find the intervals for each secondary structure element\n",
" # and add to annotation\n",
" annotation = seq.Annotation()\n",
" curr_sse = None\n",
" curr_start = None\n",
" for i in range(len(sse)):\n",
" if curr_start is None:\n",
" curr_start = i\n",
" curr_sse = sse[i]\n",
" else:\n",
" if sse[i] != sse[i-1]:\n",
" _add_sec_str(\n",
" annotation, curr_start+first_id, i-1+first_id, curr_sse\n",
" )\n",
" curr_start = i\n",
" curr_sse = sse[i]\n",
" # Add last secondary structure element to annotation\n",
" _add_sec_str(annotation, curr_start+first_id, i+first_id, curr_sse)\n",
"\n",
" fig = plt.figure(figsize=(30.0, 3.0))\n",
" ax = fig.add_subplot(111)\n",
" graphics.plot_feature_map(\n",
" ax, annotation, symbols_per_line=150,\n",
" loc_range=(first_id, first_id+len(sse)),\n",
" feature_plotters=[HelixPlotter(), SheetPlotter()]\n",
" )\n",
" fig.tight_layout()\n",
" return fig, ax\n",
"\n",
"\n",
"def plot_ss8(ss8_string):\n",
" ss3 = np.array([dssp_to_abc[e] for e in ss8_string], dtype=\"U1\")\n",
" _, ax = visualize_secondary_structure(ss3, 1)\n",
" ax.set_xticks([])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using these functions, we can visualize the secondary structure that we obtained. The alpha helices are represented in green, the beta sheets are represented in orange, and coils are represented by gray lines. \n",
"\n",
"Note: because the secondary structure assignment algorithm is not the same one as the one used by 3D visualization, this differs a bit from the cartoon representations in the 3D assignment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_ss8(protein.secondary_structure)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `function_annotations`\n",
"\n",
"An `ESMProtein` also contains function annotations derived from [InterPro](https://www.ebi.ac.uk/interpro/). Annotations directly from InterPro contain information about the following [entry types](https://interpro-documentation.readthedocs.io/en/latest/faq.html#what-are-entry-types):\n",
"* Family\n",
"* Domain\n",
"* Homologous superfamily\n",
"* Repeat\n",
"* Site (conserved site, active site, binding site, post-translational modification site)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from esm.utils.types import FunctionAnnotation\n",
"\n",
"interpro_function_annotations = [\n",
" FunctionAnnotation(label=\"IPR050145\", start=1, end=142), # 1 indexed, inclusive;\n",
" FunctionAnnotation(label=\"IPR002048\", start=4, end=75),\n",
" FunctionAnnotation(label=\"IPR002048\", start=77, end=144),\n",
" FunctionAnnotation(label=\"IPR011992\", start=1, end=143),\n",
" FunctionAnnotation(label=\"IPR018247\", start=17, end=29),\n",
" FunctionAnnotation(label=\"IPR018247\", start=53, end=65),\n",
" FunctionAnnotation(label=\"IPR018247\", start=90, end=102),\n",
" FunctionAnnotation(label=\"IPR018247\", start=126, end=138),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can visualize these InterPro annotations with the following function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Functions for visualizing InterPro function annotations\n",
"\"\"\"\n",
"\n",
"! pip install dna-features-viewer\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from esm.utils.function.interpro import InterProEntryType, InterPro\n",
"from matplotlib import colormaps\n",
"\n",
"def visualize_function_annotations(\n",
" annotations: list[FunctionAnnotation],\n",
" sequence_length: int,\n",
" ax: plt.Axes,\n",
" interpro_ = InterPro(),\n",
"):\n",
" cmap = colormaps[\"tab10\"]\n",
" colors = [cmap(i) for i in range(len(InterProEntryType))]\n",
" type_colors = dict(zip(InterProEntryType, colors))\n",
"\n",
" features = []\n",
" for annotation in annotations:\n",
" if annotation.label in interpro_.entries:\n",
" entry = interpro_.entries[annotation.label]\n",
" label = entry.name\n",
" entry_type = entry.type\n",
" else:\n",
" label = annotation.label\n",
" entry_type = InterProEntryType.UNKNOWN\n",
"\n",
" feature = GraphicFeature(\n",
" start=annotation.start - 1, # one index -> zero index\n",
" end=annotation.end,\n",
" label=label,\n",
" color=type_colors[entry_type],\n",
" strand=None,\n",
" )\n",
" features.append(feature)\n",
"\n",
" record = GraphicRecord(\n",
" sequence=None,\n",
" sequence_length=sequence_length,\n",
" features=features,\n",
" )\n",
"\n",
" record.plot(figure_width=12, plot_sequence=False, ax=ax)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We plot the InterPro annotations below, with colors indicating the entry type of the InterPro annotation "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(20.0, 4.0))\n",
"visualize_function_annotations(\n",
" interpro_function_annotations,\n",
" len(protein),\n",
" ax,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When using our `ESM3` model, we recommend you use keyword annotations, which are keywords in the description of the InterPro entry and associated Gene Ontology terms from [InterPro2GO](https://www.ebi.ac.uk/GOA/InterPro2GO). For instance, for the InterPro entry [IPR011992](https://www.ebi.ac.uk/interpro/entry/InterPro/IPR011992/), the keywords are \"domain pair\", \"hand domain\", \"ef hand\", \"pair\", and \"ef\". For more details regarding how the keywords were computed, please refer to our preprint.\n",
"\n",
"Practically, we can derive keyword annotations from the InterPro annotations with the function below. Each InterPro annotation corresponds to multiple keyword annotation covering the same range."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from esm.tokenization import InterProQuantizedTokenizer\n",
"\n",
"def get_keywords_from_interpro(\n",
" interpro_annotations,\n",
" interpro2keywords=InterProQuantizedTokenizer().interpro2keywords,\n",
"):\n",
" keyword_annotations_list = []\n",
" for interpro_annotation in interpro_annotations:\n",
" keywords = interpro2keywords.get(interpro_annotation.label, [])\n",
" keyword_annotations_list.extend([\n",
" FunctionAnnotation(\n",
" label=keyword,\n",
" start=interpro_annotation.start,\n",
" end=interpro_annotation.end,\n",
" )\n",
" for keyword in keywords\n",
" ])\n",
" return keyword_annotations_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"protein.function_annotations = get_keywords_from_interpro(interpro_function_annotations)\n",
"protein.function_annotations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also visualize the keyword annotations, which all have the same color, indicating it is not a known InterPro entry type."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(20.0, 8.0))\n",
"visualize_function_annotations(\n",
" protein.function_annotations,\n",
" len(protein),\n",
" ax,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `sasa`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final input track of `ESMProtein` is the solvent-accessible surface area, or [SASA](https://en.wikipedia.org/wiki/Accessible_surface_area). For each amino acid, this track indicates how much of it is accessible to the solvent. We can compute this by `ProteinChain`'s `sasa` function, which uses biotite's [`sasa`](https://www.biotite-python.org/apidoc/biotite.structure.sasa.html) function under the hood."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"protein.sasa = protein_chain.sasa()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to visualize this track is to represent its values as it varies along the amino acid sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(protein.sasa)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also map these SASA values onto the 3D visualization of the structure, leveraging the fact that we have this protein's 3D coordinates.\n",
"\n",
"First we define which colors map to which values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import colormaps\n",
"\n",
"cmap = colormaps['cividis']\n",
"clip_sasa_lower = 10\n",
"clip_sasa_upper = 90\n",
"\n",
"def plot_heatmap_legend(\n",
" cmap,\n",
" clip_sasa_lower,\n",
" clip_sasa_upper,\n",
"):\n",
" gradient = np.linspace(0, 1, 256)\n",
" gradient = np.vstack((gradient, gradient))\n",
" _, ax = plt.subplots(figsize=(5, 0.3), dpi=350)\n",
" ax.imshow(gradient, aspect='auto', cmap=cmap)\n",
" ax.text(0.1, -0.3, f'{clip_sasa_lower} or lower', va='center', ha='right', fontsize=7, transform=ax.transAxes)\n",
" ax.text(0.5, -0.3, f'{(clip_sasa_lower + clip_sasa_upper) // 2}', va='center', ha='right', fontsize=7, transform=ax.transAxes)\n",
" ax.text(0.9, -0.3, f'{clip_sasa_upper} or higher', va='center', ha='left', fontsize=7, transform=ax.transAxes)\n",
" ax.set_xticklabels([])\n",
" ax.set_yticklabels([])\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
" plt.show()\n",
"\n",
"plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"Functions for visualizing SASA as colors on the 3D structure\n",
"\"\"\"\n",
"\n",
"def get_color_strings(\n",
" sasa,\n",
" clip_sasa_lower,\n",
" clip_sasa_upper,\n",
" cmap,\n",
"):\n",
" transformed_sasa = np.clip(sasa, clip_sasa_lower, clip_sasa_upper)\n",
" transformed_sasa = (transformed_sasa - clip_sasa_lower) / (clip_sasa_upper - clip_sasa_lower)\n",
" rgbas = (cmap(transformed_sasa) * 255).astype(int)\n",
"\n",
" return [\n",
" f'rgb({rgba[0]},{rgba[1]},{rgba[2]})'\n",
" for rgba in rgbas\n",
" ] \n",
"\n",
"def visualize_sasa_3D_protein(\n",
" protein,\n",
" clip_sasa_lower=clip_sasa_lower,\n",
" clip_sasa_upper=clip_sasa_upper,\n",
" cmap=cmap,\n",
"):\n",
" pdb_string = protein.to_pdb_string()\n",
" plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper)\n",
" view = py3Dmol.view(width=400, height=400)\n",
" view.addModel(pdb_string, \"pdb\")\n",
"\n",
" for res_pos, res_color in enumerate(get_color_strings(protein.sasa, clip_sasa_lower, clip_sasa_upper, cmap)):\n",
" view.setStyle({'chain': 'A', 'resi': res_pos+1}, {'cartoon': {'color': res_color}})\n",
" view.zoomTo()\n",
" view.render()\n",
" view.center()\n",
"\n",
" return view"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We visualize SASA on the 3D structure below. Note that the amino acids that are on the inside have lower SASA values, and the amino acids at the surface have higher SASA values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"visualize_sasa_3D_protein(protein)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have now covered all the tracks of `ESMProtein`. \n",
"\n",
"We can initialize an `ESMProtein` by providing any of these tracks. For instance, to initialize a protein with the same coordinates as our `protein`, we would do:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"same_structure_protein = ESMProtein(coordinates=protein.coordinates)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and similarly for any other track."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We hope this helps you get started with our ESM3 models!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View File

@@ -2,6 +2,8 @@ from esm.models.esm3 import ESM3
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
GenerationConfig,
SamplingConfig,
SamplingTrackConfig,
@@ -33,6 +35,20 @@ def main(client: ESM3InferenceClient):
single_step_protein.protein_tensor.sequence = protein.sequence
single_step_protein = client.decode(single_step_protein.protein_tensor)
# Generate with partial sequence.
prompt = (
"___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLK"
"GGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIK"
"TKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
)
protein = ESMProtein(sequence=prompt)
protein = client.generate(
protein,
GenerationConfig(track="sequence", num_steps=8, temperature=0.7),
)
assert isinstance(protein, ESMProtein)
print(protein.sequence)
# Folding
protein = get_sample_protein()
sequence_length = len(protein.sequence) # type: ignore
@@ -44,9 +60,10 @@ def main(client: ESM3InferenceClient):
protein,
GenerationConfig(track="structure", schedule="cosine", num_steps=num_steps),
)
assert isinstance(folded_protein, ESMProtein)
folded_protein.to_pdb("./sample_folded.pdb")
# Inverse Folding
# Inverse folding
protein = get_sample_protein()
protein.sequence = None
protein.sasa = None
@@ -55,8 +72,19 @@ def main(client: ESM3InferenceClient):
protein,
GenerationConfig(track="sequence", schedule="cosine", num_steps=num_steps),
)
assert isinstance(inv_folded_protein, ESMProtein)
inv_folded_protein.to_pdb("./sample_inv_folded.pdb")
# Function prediction
protein = get_sample_protein()
protein.function_annotations = None
protein_with_function = client.generate(
protein,
GenerationConfig(track="function", schedule="cosine", num_steps=num_steps),
)
assert isinstance(protein_with_function, ESMProtein)
print(protein_with_function.function_annotations)
# Chain of Thought (Function -> Secondary Structure -> Structure -> Sequence)
cot_protein = get_sample_protein()
cot_protein.sequence = "_" * len(cot_protein.sequence) # type: ignore
@@ -68,9 +96,51 @@ def main(client: ESM3InferenceClient):
cot_protein_tensor,
GenerationConfig(track=cot_track, schedule="cosine", num_steps=10),
)
assert isinstance(cot_protein_tensor, ESMProteinTensor)
cot_protein = client.decode(cot_protein_tensor)
assert isinstance(cot_protein, ESMProtein)
cot_protein.to_pdb("./sample_cot.pdb")
# Batch examples.
# Batch generation.
prompts = [ESMProtein(sequence=("_" * (10 + 2 * i))) for i in range(5)]
configs = [
GenerationConfig(track="sequence", schedule="cosine", num_steps=(i + 1))
for i in range(5)
]
proteins = client.batch_generate(prompts, configs)
# Batch folding.
# Take the list of proteins batch-generated in the previous step.
configs = [
GenerationConfig(track="structure", schedule="cosine", num_steps=(i + 1))
for i in range(5)
]
# Generate again for the structure track.
proteins = client.batch_generate(proteins, configs)
# Now write sequence and structure to PDB files.
for i, p in enumerate(proteins):
assert isinstance(p, ESMProtein)
p.to_pdb(f"./batch_gen_{i}.pdb")
# Batch generation returns ESMProteinError for specific prompts that have issues.
prompts = [ESMProtein(sequence=("_" * (10 + 2 * i))) for i in range(5)]
# Mock error situation. The third prompt has no masks to be sampled.
prompts[2].sequence = "ANTVPYQ"
configs = [
GenerationConfig(track="sequence", schedule="cosine", num_steps=(i + 1))
for i in range(5)
]
proteins = client.batch_generate(prompts, configs)
# Should still get results, but the third result is an ESMProteinError.
for i, p in enumerate(proteins):
if i == 2:
assert isinstance(p, ESMProteinError)
else:
assert isinstance(p, ESMProtein)
if __name__ == "__main__":
main(ESM3.from_pretrained("esm3_sm_open_v1"))

View File

@@ -34,6 +34,7 @@ dependencies = [
"brotli",
"attrs",
"pandas",
"cloudpathlib",
]
[tool.setuptools]