mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
v0.2rc1
This commit is contained in:
@@ -0,0 +1 @@
|
||||
__version__ = "0.2rc1"
|
||||
|
||||
@@ -59,12 +59,19 @@ class MultiHeadAttention(nn.Module):
|
||||
reshaper, (query_BLD, key_BLD, value_BLD)
|
||||
)
|
||||
|
||||
# Where True, enable participation in attention.
|
||||
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
|
||||
mask_BHLL = mask_BLL.unsqueeze(1)
|
||||
if seq_id is not None:
|
||||
# Where True, enable participation in attention.
|
||||
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
|
||||
mask_BHLL = mask_BLL.unsqueeze(1)
|
||||
|
||||
context_BHLD = F.scaled_dot_product_attention(
|
||||
query_BHLD, key_BHLD, value_BHLD, mask_BHLL
|
||||
)
|
||||
context_BHLD = F.scaled_dot_product_attention(
|
||||
query_BHLD, key_BHLD, value_BHLD, mask_BHLL
|
||||
)
|
||||
else:
|
||||
# Shortcut, if we don't use attention biases then torch
|
||||
# will autoselect flashattention as the implementation
|
||||
context_BHLD = F.scaled_dot_product_attention(
|
||||
query_BHLD, key_BHLD, value_BHLD
|
||||
)
|
||||
context_BLD = einops.rearrange(context_BHLD, "b h s d -> b s (h d)")
|
||||
return self.out_proj(context_BLD)
|
||||
|
||||
@@ -50,6 +50,8 @@ class GeometricReasoningOriginalImpl(nn.Module):
|
||||
self.rotation_scale_per_head = nn.Parameter(torch.zeros((self.v_heads)))
|
||||
|
||||
def forward(self, s, affine, affine_mask, sequence_id, chain_id):
|
||||
if sequence_id is None:
|
||||
sequence_id = torch.zeros_like(s[..., 0], dtype=torch.int64)
|
||||
attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
|
||||
attn_bias = attn_bias.unsqueeze(1).float()
|
||||
attn_bias = attn_bias.masked_fill(
|
||||
|
||||
@@ -83,10 +83,6 @@ class TransformerStack(nn.Module):
|
||||
pre_norm: The embedding of shape (batch_size, sequence_length, d_model).
|
||||
"""
|
||||
*batch_dims, _ = x.shape
|
||||
if sequence_id is None:
|
||||
sequence_id = torch.ones(
|
||||
size=batch_dims, dtype=torch.int64, device=x.device
|
||||
)
|
||||
if chain_id is None:
|
||||
chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
|
||||
for block in self.blocks:
|
||||
|
||||
@@ -26,9 +26,7 @@ from esm.sdk.api import (
|
||||
ForwardTrackData,
|
||||
GenerationConfig,
|
||||
ProteinType,
|
||||
ReturnLogitsConfig,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import get_model_tokenizers
|
||||
from esm.utils import encoding
|
||||
@@ -36,15 +34,16 @@ from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.decoding import decode_protein_tensor
|
||||
from esm.utils.generation import (
|
||||
_batch_forward,
|
||||
_sample_per_prompt,
|
||||
_slice_tensor_dataclass,
|
||||
iterative_sampling_raw,
|
||||
iterative_sampling_tokens,
|
||||
)
|
||||
from esm.utils.misc import rbf
|
||||
from esm.utils.sampling import (
|
||||
_BatchedESMProteinTensor,
|
||||
get_default_sampling_config,
|
||||
sample_function_logits,
|
||||
sample_logits,
|
||||
sample_residue_annotation_logits,
|
||||
)
|
||||
from esm.utils.structure.affine3d import (
|
||||
build_affine3d_from_coordinates,
|
||||
@@ -222,9 +221,9 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
self.structure_decoder_name = structure_decoder_name
|
||||
self.function_decoder_name = function_decoder_name
|
||||
|
||||
self.structure_encoder: StructureTokenEncoder | None = None # type: ignore
|
||||
self.structure_decoder: StructureTokenDecoder | None = None # type: ignore
|
||||
self.function_decoder: FunctionTokenDecoder | None = None # type: ignore
|
||||
self.structure_encoder: StructureTokenEncoder | None = None
|
||||
self.structure_decoder: StructureTokenDecoder | None = None
|
||||
self.function_decoder: FunctionTokenDecoder | None = None
|
||||
|
||||
self.tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)
|
||||
|
||||
@@ -232,29 +231,40 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_name: str = ESM3_OPEN_SMALL,
|
||||
device: torch.device | str = "cpu",
|
||||
device: torch.device | None = None,
|
||||
) -> ESM3:
|
||||
from esm.pretrained import load_local_model
|
||||
|
||||
if model_name not in [ESM3_OPEN_SMALL]:
|
||||
raise ValueError(f"Model name {model_name} is not a valid ESM3 model name.")
|
||||
model: ESM3 = load_local_model(model_name, device=device) # type: ignore
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = load_local_model(model_name, device=device)
|
||||
if device.type != "cpu":
|
||||
model = model.to(torch.bfloat16)
|
||||
assert isinstance(model, ESM3)
|
||||
return model
|
||||
|
||||
def get_structure_token_encoder(self) -> StructureTokenEncoder:
|
||||
if self.structure_encoder is None:
|
||||
self.structure_encoder = self.load_model(self.structure_encoder_name) # type: ignore
|
||||
return self.structure_encoder # type: ignore
|
||||
model = self.load_model(self.structure_encoder_name)
|
||||
assert isinstance(model, StructureTokenEncoder)
|
||||
self.structure_encoder = model
|
||||
return self.structure_encoder
|
||||
|
||||
def get_structure_token_decoder(self) -> StructureTokenDecoder:
|
||||
if self.structure_decoder is None:
|
||||
self.structure_decoder = self.load_model(self.structure_decoder_name) # type: ignore
|
||||
return self.structure_decoder # type: ignore
|
||||
model = self.load_model(self.structure_decoder_name)
|
||||
assert isinstance(model, StructureTokenDecoder)
|
||||
self.structure_decoder = model
|
||||
return self.structure_decoder
|
||||
|
||||
def get_function_token_decoder(self) -> FunctionTokenDecoder:
|
||||
if self.function_decoder is None:
|
||||
self.function_decoder = self.load_model(self.function_decoder_name) # type: ignore
|
||||
return self.function_decoder # type: ignore
|
||||
model = self.load_model(self.function_decoder_name)
|
||||
assert isinstance(model, FunctionTokenDecoder)
|
||||
self.function_decoder = model
|
||||
return self.function_decoder
|
||||
|
||||
def load_model(self, model_name: str):
|
||||
# Lazy import from pretrained
|
||||
@@ -324,12 +334,11 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
torch.full((1, L), tok, dtype=torch.long, device=device) if x is None else x
|
||||
)
|
||||
sequence_tokens = defaults(sequence_tokens, t.sequence.mask_token_id)
|
||||
ss8_tokens = defaults(ss8_tokens, C.SS8_UNK_TOKEN)
|
||||
sasa_tokens = defaults(sasa_tokens, C.SASA_UNK_TOKEN)
|
||||
ss8_tokens = defaults(ss8_tokens, C.SS8_PAD_TOKEN)
|
||||
sasa_tokens = defaults(sasa_tokens, C.SASA_PAD_TOKEN)
|
||||
average_plddt = defaults(average_plddt, 1).float()
|
||||
per_res_plddt = defaults(per_res_plddt, 0).float()
|
||||
chain_id = defaults(chain_id, 0)
|
||||
sequence_id = defaults(sequence_id, 0)
|
||||
|
||||
if residue_annotation_tokens is None:
|
||||
residue_annotation_tokens = torch.full(
|
||||
@@ -384,10 +393,39 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
|
||||
# The following methods are for the ESM3InferenceClient interface
|
||||
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
|
||||
if isinstance(input, ESMProtein):
|
||||
return iterative_sampling_raw(self, input, config)
|
||||
elif isinstance(input, ESMProteinTensor):
|
||||
return iterative_sampling_tokens(self, input, config, self.tokenizers)
|
||||
"""Wrap around batched generation."""
|
||||
proteins = self.batch_generate([input], [config])
|
||||
assert len(proteins) == 1
|
||||
return proteins[0]
|
||||
|
||||
def batch_generate(
    self, inputs: list[ProteinType], configs: list[GenerationConfig]
) -> list[ProteinType]:
    """Generate one output per prompt, handling all prompts as one batch.

    Args:
        inputs: Prompts to generate from. Every prompt must be an instance
            of the type of the first prompt.
        configs: One generation config per prompt, in the same order.

    Returns:
        Whatever the batched sampling routine returns for the prompts
        (`iterative_sampling_raw` for ESMProtein prompts,
        `iterative_sampling_tokens` for ESMProteinTensor prompts), or an
        empty list when `inputs` is empty.

    Raises:
        AssertionError: If `inputs` and `configs` differ in length, or the
            prompts are not all of the same type.
        ValueError: If the prompts are neither ESMProtein nor
            ESMProteinTensor.
    """
    assert len(inputs) == len(
        configs
    ), "Must have the same number of prompts and configs."

    # An identity test (`inputs is []`) is never true for a caller-supplied
    # list; check emptiness so an empty batch never reaches `inputs[0]`.
    if len(inputs) == 0:
        # Nothing to do.
        return []

    # Make sure prompts are of the same type.
    first_type = type(inputs[0])
    for prompt in inputs[1:]:
        assert isinstance(prompt, first_type), (
            "Prompts must have the same type. Got "
            f"{first_type.__name__} and {type(prompt).__name__} instead."
        )

    if isinstance(inputs[0], ESMProtein):
        return iterative_sampling_raw(self, inputs, configs)  # type: ignore
    elif isinstance(inputs[0], ESMProteinTensor):
        return iterative_sampling_tokens(
            self,
            inputs,  # type: ignore
            configs,
            self.tokenizers,  # type: ignore
        )
    else:
        raise ValueError("Input must be an ESMProtein or ESMProteinTensor")
|
||||
|
||||
@@ -486,6 +524,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
def _forward(
|
||||
self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
|
||||
) -> ForwardOutput:
|
||||
device = torch.device(input.device)
|
||||
# Default plddt conditioning for inference. 1s where coordinates are provided.
|
||||
if input.coordinates is None:
|
||||
per_res_plddt = None
|
||||
@@ -493,7 +532,12 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
# 1.0 if all coordinates at specific indices have valid non-nan values.
|
||||
per_res_plddt = input.coordinates.isfinite().all(dim=-1).any(dim=-1).float()
|
||||
|
||||
with torch.no_grad() if self.eval else contextlib.nullcontext():
|
||||
with (
|
||||
torch.no_grad(), # Assume no gradients for now...
|
||||
torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore
|
||||
if device.type == "cuda"
|
||||
else contextlib.nullcontext(),
|
||||
):
|
||||
output = self.forward(
|
||||
sequence_tokens=input.sequence,
|
||||
structure_tokens=input.structure,
|
||||
@@ -508,31 +552,32 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
sequence_id=None,
|
||||
)
|
||||
|
||||
if config.return_logits:
|
||||
logits = ForwardTrackData(
|
||||
sequence=output.sequence_logits,
|
||||
structure=output.structure_logits,
|
||||
secondary_structure=output.secondary_structure_logits,
|
||||
sasa=output.sasa_logits,
|
||||
function=output.function_logits,
|
||||
)
|
||||
else:
|
||||
logits = None
|
||||
output = ESMOutput(
|
||||
**{k: v.to(device).to(torch.float32) for k, v in vars(output).items()}
|
||||
)
|
||||
|
||||
return ForwardOutput(
|
||||
logits=logits,
|
||||
residue_annotation_logits=output.residue_logits,
|
||||
embeddings=output.embeddings if config.return_embeddings else None,
|
||||
if config.return_logits:
|
||||
logits = ForwardTrackData(
|
||||
sequence=output.sequence_logits,
|
||||
structure=output.structure_logits,
|
||||
secondary_structure=output.secondary_structure_logits,
|
||||
sasa=output.sasa_logits,
|
||||
function=output.function_logits,
|
||||
)
|
||||
else:
|
||||
logits = None
|
||||
|
||||
return ForwardOutput(
|
||||
logits=logits,
|
||||
residue_annotation_logits=output.residue_logits,
|
||||
embeddings=output.embeddings if config.return_embeddings else None,
|
||||
)
|
||||
|
||||
def forward_and_sample(
|
||||
self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
|
||||
) -> ForwardAndSampleOutput:
|
||||
protein_tensor = attr.evolve(input) # Make a copy
|
||||
|
||||
def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
|
||||
return x.clone() if x is not None else None
|
||||
|
||||
device = next(self.parameters()).device
|
||||
|
||||
sampling_config = sampling_configuration
|
||||
@@ -551,249 +596,20 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
getattr(default_protein_tensor, track.name, None),
|
||||
)
|
||||
|
||||
# Preprocessing
|
||||
sequence_length: int = -1
|
||||
for track in [
|
||||
"sequence",
|
||||
"structure",
|
||||
"secondary_structure",
|
||||
"sasa",
|
||||
"function",
|
||||
"residue_annotations",
|
||||
]:
|
||||
input_tensor: torch.Tensor | None = getattr(protein_tensor, track, None)
|
||||
if input_tensor is not None:
|
||||
# Add batch dimension if necessary
|
||||
if track in ["sequence", "structure", "secondary_structure", "sasa"]:
|
||||
if len(input_tensor.size()) == 1:
|
||||
input_tensor = input_tensor.unsqueeze(0) # (L,) -> (1, L)
|
||||
elif track in ["function", "residue_annotations"]:
|
||||
if len(input_tensor.size()) == 2:
|
||||
input_tensor = input_tensor.unsqueeze(0) # (L, O) -> (1, L, O)
|
||||
|
||||
# Check length consistency
|
||||
if sequence_length == -1:
|
||||
sequence_length = input_tensor.size(1)
|
||||
else:
|
||||
if input_tensor.size(1) != sequence_length:
|
||||
raise ValueError(
|
||||
f"Length mismatch for track {track}. Expected {sequence_length}, got {input_tensor.size(1)}"
|
||||
)
|
||||
|
||||
# Move input tensor to model device
|
||||
input_tensor = input_tensor.to(device)
|
||||
setattr(protein_tensor, track, input_tensor)
|
||||
|
||||
if protein_tensor.coordinates is not None:
|
||||
coordinates = protein_tensor.coordinates
|
||||
if len(coordinates.size()) == 3:
|
||||
coordinates = coordinates.unsqueeze(0)
|
||||
protein_tensor.coordinates = coordinates.to(device)
|
||||
sequence_length = coordinates.size(1)
|
||||
|
||||
if sequence_length == -1:
|
||||
if len(protein_tensor) <= 0:
|
||||
raise ValueError("No input data provided")
|
||||
|
||||
# Forward pass
|
||||
forward_output = self._forward(
|
||||
protein_tensor,
|
||||
ForwardConfig(
|
||||
ReturnLogitsConfig(
|
||||
sequence=True,
|
||||
structure=True,
|
||||
secondary_structure=True,
|
||||
sasa=True,
|
||||
function=True,
|
||||
residue_annotations=True,
|
||||
),
|
||||
return_embeddings=True,
|
||||
),
|
||||
# Move input protein to proper device.
|
||||
batched_protein = _BatchedESMProteinTensor.from_protein_tensor(protein_tensor)
|
||||
batched_protein.to(device)
|
||||
|
||||
forward_output: ForwardOutput = _batch_forward(self, batched_protein)
|
||||
forward_and_sample_out: ForwardAndSampleOutput = _sample_per_prompt(
|
||||
batched_protein,
|
||||
forward_output,
|
||||
sampling_config,
|
||||
self.tokenizers,
|
||||
)
|
||||
|
||||
# Sampling
|
||||
tokens_dir = {}
|
||||
track_sampling_metadata_dir: dict[str, dict | None] = {}
|
||||
for track in ["sequence", "structure", "secondary_structure", "sasa"]:
|
||||
config = getattr(sampling_config, track)
|
||||
if config is None:
|
||||
tokens_dir[track] = maybe_clone(getattr(input, track))
|
||||
continue
|
||||
sampling_metadata = self._sample_track(
|
||||
logits=getattr(forward_output.logits, track)[0, ...],
|
||||
tokens=getattr(protein_tensor, track)[0, ...],
|
||||
sampling_track_config=config,
|
||||
mask_idx=getattr(self.tokenizers, track).mask_token_id,
|
||||
)
|
||||
tokens_dir[track] = sampling_metadata.pop("sampled_tokens") # (L,)
|
||||
track_sampling_metadata_dir[track] = sampling_metadata
|
||||
|
||||
# Sample function and residue annotations separately
|
||||
config = getattr(sampling_config, "function")
|
||||
if config is None:
|
||||
tokens_dir["function"] = maybe_clone(getattr(input, "function"))
|
||||
tokens_dir["residue_annotations"] = maybe_clone(
|
||||
getattr(input, "residue_annotations")
|
||||
)
|
||||
else:
|
||||
sampling_metadata = self._sample_function_track(
|
||||
tokens=getattr(protein_tensor, "function")[0, ...],
|
||||
logits=getattr(forward_output.logits, "function")[0, ...],
|
||||
sampling_track_config=config,
|
||||
)
|
||||
tokens_dir["function"] = sampling_metadata.pop("sampled_tokens") # (L, D)
|
||||
track_sampling_metadata_dir["function"] = sampling_metadata
|
||||
|
||||
sampled_tokens, _ = sample_residue_annotation_logits(
|
||||
logits=forward_output.residue_annotation_logits[0, ...] # type: ignore
|
||||
)
|
||||
tokens_dir["residue_annotations"] = sampled_tokens # (L, MAX_R)
|
||||
|
||||
# Format output
|
||||
forward_and_sample_output_dir = {}
|
||||
forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
|
||||
for property in [
|
||||
"entropy",
|
||||
"prob",
|
||||
"logprob",
|
||||
"top_prob",
|
||||
"topk_logprob",
|
||||
"topk_tokens",
|
||||
]:
|
||||
is_all_none = True
|
||||
forward_track_data_dir = {}
|
||||
for track in track_sampling_metadata_dir.keys():
|
||||
values = track_sampling_metadata_dir[track]
|
||||
if values is not None and values.get(property, None) is not None:
|
||||
forward_track_data_dir[track] = values.get(property, None)
|
||||
is_all_none = False
|
||||
if not is_all_none:
|
||||
forward_and_sample_output_dir[property] = ForwardTrackData(
|
||||
**forward_track_data_dir
|
||||
)
|
||||
else:
|
||||
forward_and_sample_output_dir[property] = None
|
||||
|
||||
perres_embed = (
|
||||
forward_output.embeddings[0] # type: ignore
|
||||
if sampling_configuration.return_per_residue_embeddings
|
||||
else None
|
||||
)
|
||||
mean_embedding = (
|
||||
forward_output.embeddings[0].mean(0) # type: ignore
|
||||
if sampling_configuration.return_mean_embedding
|
||||
else None
|
||||
)
|
||||
|
||||
return ForwardAndSampleOutput(
|
||||
per_residue_embedding=perres_embed,
|
||||
mean_embedding=mean_embedding,
|
||||
**forward_and_sample_output_dir,
|
||||
)
|
||||
|
||||
def _sample_track(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
tokens: torch.Tensor,
|
||||
sampling_track_config: SamplingTrackConfig,
|
||||
mask_idx: int,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
# Sample in all positions
|
||||
temperature = sampling_track_config.temperature
|
||||
sampled_tokens = sample_logits(
|
||||
logits, temperature=temperature, top_p=sampling_track_config.top_p
|
||||
)
|
||||
log_probs = logits.log_softmax(-1)
|
||||
|
||||
# Do not sample at BOS and EOS tokens
|
||||
sampling_mask = torch.ones_like(tokens, dtype=torch.bool) # (L, )
|
||||
sampling_mask[0] = False
|
||||
sampling_mask[-1] = False
|
||||
|
||||
# Do not sample at special token positions but allow sampling at mask token
|
||||
special_minus_mask = list(set(sampling_track_config.invalid_ids) - {mask_idx})
|
||||
if len(special_minus_mask) > 0:
|
||||
special_tokens = torch.tensor(special_minus_mask, device=tokens.device)
|
||||
assert special_tokens.numel() > 0
|
||||
sampling_mask = sampling_mask & (
|
||||
tokens[..., None] != special_tokens[None, :]
|
||||
).all(-1)
|
||||
|
||||
# Keep only samples from masked positions (if specified)
|
||||
if sampling_track_config.only_sample_masked_tokens:
|
||||
masked_tokens = tokens == mask_idx
|
||||
sampling_mask = sampling_mask & masked_tokens
|
||||
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens)
|
||||
|
||||
return self._compute_track_metadata(
|
||||
sampled_tokens,
|
||||
log_probs,
|
||||
sampling_mask,
|
||||
top_k=sampling_track_config.topk_logprobs,
|
||||
)
|
||||
|
||||
def _sample_function_track(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
logits: torch.Tensor,
|
||||
sampling_track_config: SamplingTrackConfig,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
# Do not sample at BOS and EOS tokens
|
||||
sampling_mask = torch.ones_like(tokens, dtype=torch.bool)
|
||||
sampling_mask[0] = False
|
||||
sampling_mask[-1] = False
|
||||
|
||||
sampled_tokens, probs = sample_function_logits(
|
||||
logits,
|
||||
self.tokenizers.function,
|
||||
top_p=sampling_track_config.top_p,
|
||||
temperature=sampling_track_config.temperature,
|
||||
)
|
||||
|
||||
if sampling_track_config.only_sample_masked_tokens:
|
||||
raise ValueError(
|
||||
"Sampling only masked tokens is undefined for function tokens."
|
||||
)
|
||||
|
||||
sampled_tokens = torch.where(sampling_mask, sampled_tokens, tokens) # (L, D)
|
||||
|
||||
return self._compute_track_metadata(
|
||||
sampled_tokens,
|
||||
probs,
|
||||
sampling_mask,
|
||||
top_k=sampling_track_config.topk_logprobs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_track_metadata(
|
||||
sampled_tokens: torch.Tensor,
|
||||
log_probs: torch.Tensor,
|
||||
sampling_mask: torch.Tensor,
|
||||
top_k: int,
|
||||
) -> dict:
|
||||
probs = torch.exp(log_probs) # (B, L)
|
||||
entropy = torch.distributions.Categorical(probs=probs).entropy() # (B, L)
|
||||
|
||||
# Only compute probabilities for sampled tokens
|
||||
sampled_logprob = torch.zeros_like(
|
||||
sampled_tokens, dtype=torch.float32
|
||||
) # (B, L)
|
||||
sampled_tokens_valid = sampled_tokens[sampling_mask]
|
||||
sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid]
|
||||
sampled_logprob[sampling_mask] = sampled_log_probs_valid
|
||||
|
||||
# Calculate extra metadata
|
||||
sampled_prob = torch.exp(sampled_logprob)
|
||||
top_prob = torch.max(probs, dim=-1).values
|
||||
topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
|
||||
topk_logprobs = None if top_k == 0 else topk_logprobs
|
||||
topk_tokens = None if top_k == 0 else topk_tokens
|
||||
|
||||
return {
|
||||
"entropy": entropy,
|
||||
"sampled_tokens": sampled_tokens,
|
||||
"prob": sampled_prob,
|
||||
"logprob": sampled_logprob,
|
||||
"top_prob": top_prob,
|
||||
"topk_logprob": topk_logprobs,
|
||||
"topk_tokens": topk_tokens,
|
||||
}
|
||||
# There is only 1 prompt to sample for.
|
||||
return _slice_tensor_dataclass(forward_and_sample_out, 0)
|
||||
|
||||
@@ -8,6 +8,7 @@ import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from cloudpathlib import AnyPath
|
||||
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
@@ -70,7 +71,7 @@ class FunctionTokenDecoder(nn.Module):
|
||||
}
|
||||
assert len(self.interpro_ids) == config.num_interpro_classes
|
||||
|
||||
with open(config.keyword_vocabulary_path, "r") as f:
|
||||
with AnyPath(config.keyword_vocabulary_path).open("r") as f:
|
||||
self.keywords_vocabulary: list[str] = list(f.read().strip().split("\n"))
|
||||
assert len(self.keywords_vocabulary) == config.keyword_vocabulary_size
|
||||
|
||||
@@ -245,8 +246,8 @@ class FunctionTokenDecoder(nn.Module):
|
||||
interpro_id = self.interpro_ids[class_index]
|
||||
annotation = FunctionAnnotation(
|
||||
label=interpro_id,
|
||||
start=position_index + 1, # zero-index -> one-index inclusive
|
||||
end=position_index + 1, # zero-index -> one-index inclusive
|
||||
start=position_index, # one-index inclusive (BOS shifts indexes +1)
|
||||
end=position_index, # one-index inclusive
|
||||
)
|
||||
annotations.append(annotation)
|
||||
|
||||
@@ -300,8 +301,8 @@ class FunctionTokenDecoder(nn.Module):
|
||||
for range_ in merge_ranges(ranges):
|
||||
annotation = FunctionAnnotation(
|
||||
label=keyword,
|
||||
start=range_.start + 1, # zero-index -> one-index
|
||||
end=range_.stop + 1 - 1, # zero-index excl -> one-index incl
|
||||
start=range_.start, # one-index inclusive (BOS shifts indexes +1)
|
||||
end=range_.stop - 1, # one-index exclusive -> one-index inclusive
|
||||
)
|
||||
annotations.append(annotation)
|
||||
|
||||
@@ -332,8 +333,8 @@ def _merge_annotations(
|
||||
for range_ in merged_ranges:
|
||||
annotation = FunctionAnnotation(
|
||||
label=label,
|
||||
start=range_.start + 1, # zero-index -> one-index
|
||||
end=range_.stop - 1, # zero-index excl -> one-index incl
|
||||
start=range_.start, # one-index inclusive (BOS shifts indexes +1)
|
||||
end=range_.stop - 1, # one-index exclusive -> one-index inclusive
|
||||
)
|
||||
merged.append(annotation)
|
||||
return merged
|
||||
|
||||
@@ -235,12 +235,14 @@ class StructureTokenEncoder(nn.Module):
|
||||
knn_sequence_id = (
|
||||
node_gather(sequence_id.unsqueeze(-1), knn_edges).view(-1, E)
|
||||
if sequence_id is not None
|
||||
else torch.zeros(L, E, dtype=torch.int64, device=coords.device)
|
||||
else torch.zeros(B * L, E, dtype=torch.int64, device=coords.device)
|
||||
)
|
||||
knn_affine_mask = node_gather(affine_mask.unsqueeze(-1), knn_edges).view(
|
||||
-1, E
|
||||
)
|
||||
knn_chain_id = torch.zeros(L, E, dtype=torch.int64, device=coords.device)
|
||||
knn_chain_id = torch.zeros(
|
||||
B * L, E, dtype=torch.int64, device=coords.device
|
||||
)
|
||||
|
||||
if residue_index is None:
|
||||
res_idxs = knn_edges.view(-1, E)
|
||||
|
||||
@@ -21,8 +21,8 @@ ModelBuilder = Callable[[torch.device | str], nn.Module]
|
||||
|
||||
|
||||
def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
||||
model = (
|
||||
ESM3(
|
||||
with torch.device(device):
|
||||
model = ESM3(
|
||||
d_model=1536,
|
||||
n_heads=24,
|
||||
v_heads=256,
|
||||
@@ -30,10 +30,7 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
||||
structure_encoder_name=ESM3_STRUCTURE_ENCODER_V0,
|
||||
structure_decoder_name=ESM3_STRUCTURE_DECODER_V0,
|
||||
function_decoder_name=ESM3_FUNCTION_DECODER_V0,
|
||||
)
|
||||
.to(device)
|
||||
.eval()
|
||||
)
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root() / "data/weights/esm3_sm_open_v1.pth", map_location=device
|
||||
)
|
||||
@@ -42,13 +39,10 @@ def ESM3_sm_open_v0(device: torch.device | str = "cpu"):
|
||||
|
||||
|
||||
def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
|
||||
model = (
|
||||
StructureTokenEncoder(
|
||||
with torch.device(device):
|
||||
model = StructureTokenEncoder(
|
||||
d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
|
||||
)
|
||||
.to(device)
|
||||
.eval()
|
||||
)
|
||||
).eval()
|
||||
state_dict = torch.load(
|
||||
data_root() / "data/weights/esm3_structure_encoder_v0.pth", map_location=device
|
||||
)
|
||||
@@ -57,9 +51,8 @@ def ESM3_structure_encoder_v0(device: torch.device | str = "cpu"):
|
||||
|
||||
|
||||
def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
|
||||
model = (
|
||||
StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).to(device).eval()
|
||||
)
|
||||
with torch.device(device):
|
||||
model = StructureTokenDecoder(d_model=1280, n_heads=20, n_layers=30).eval()
|
||||
state_dict = torch.load(
|
||||
data_root() / "data/weights/esm3_structure_decoder_v0.pth", map_location=device
|
||||
)
|
||||
@@ -68,7 +61,8 @@ def ESM3_structure_decoder_v0(device: torch.device | str = "cpu"):
|
||||
|
||||
|
||||
def ESM3_function_decoder_v0(device: torch.device | str = "cpu"):
|
||||
model = FunctionTokenDecoder().to(device).eval()
|
||||
with torch.device(device):
|
||||
model = FunctionTokenDecoder().eval()
|
||||
state_dict = torch.load(
|
||||
data_root() / "data/weights/esm3_function_decoder_v0.pth", map_location=device
|
||||
)
|
||||
@@ -84,7 +78,9 @@ LOCAL_MODEL_REGISTRY: dict[str, ModelBuilder] = {
|
||||
}
|
||||
|
||||
|
||||
def load_local_model(model_name: str, device: torch.device | str = "cpu") -> nn.Module:
|
||||
def load_local_model(
|
||||
model_name: str, device: torch.device = torch.device("cpu")
|
||||
) -> nn.Module:
|
||||
if model_name not in LOCAL_MODEL_REGISTRY:
|
||||
raise ValueError(f"Model {model_name} not found in local model registry.")
|
||||
return LOCAL_MODEL_REGISTRY[model_name](device)
|
||||
|
||||
11
esm/sdk/__init__.py
Normal file
11
esm/sdk/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
|
||||
from esm.sdk.forge import ESM3ForgeInferenceClient
|
||||
|
||||
|
||||
def client(
    model="esm3-sm-open-v1",
    url="https://forge.evolutionaryscale.ai",
    token=None,
):
    """Build an ESM3ForgeInferenceClient for the given model and endpoint.

    Args:
        model: Model name passed through to the client.
        url: Base URL passed through to the client.
        token: API token. When None, the ESM_API_KEY environment variable
            is read at call time, falling back to an empty string.

    Returns:
        An ESM3ForgeInferenceClient constructed from (model, url, token).
    """
    # Resolve the environment variable here, not in the signature: default
    # argument values are evaluated once at import, which would ignore a key
    # exported after this module has been imported.
    if token is None:
        token = os.environ.get("ESM_API_KEY", "")
    return ESM3ForgeInferenceClient(model, url, token)
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Sequence, TypeVar
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import torch
|
||||
from attr import define
|
||||
from attr import asdict, define
|
||||
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
@@ -21,9 +21,13 @@ from esm.utils.types import (
|
||||
)
|
||||
|
||||
|
||||
class ProteinType(ABC):
|
||||
...
|
||||
|
||||
|
||||
## Basic Types
|
||||
@define
|
||||
class ESMProtein:
|
||||
class ESMProtein(ProteinType):
|
||||
# Tracks
|
||||
sequence: str | None = None
|
||||
secondary_structure: str | None = None
|
||||
@@ -101,13 +105,15 @@ class ESMProtein:
|
||||
entity_id=None,
|
||||
residue_index=None,
|
||||
insertion_code=None,
|
||||
confidence=None if self.plddt is None else self.plddt.detach().cpu().numpy(),
|
||||
confidence=None
|
||||
if self.plddt is None
|
||||
else self.plddt.detach().cpu().numpy(),
|
||||
)
|
||||
return protein_chain
|
||||
|
||||
|
||||
@define
|
||||
class ESMProteinTensor:
|
||||
class ESMProteinTensor(ProteinType):
|
||||
sequence: torch.Tensor | None = None
|
||||
structure: torch.Tensor | None = None
|
||||
secondary_structure: torch.Tensor | None = None
|
||||
@@ -116,59 +122,33 @@ class ESMProteinTensor:
|
||||
residue_annotations: torch.Tensor | None = None
|
||||
coordinates: torch.Tensor | None = None
|
||||
|
||||
def _detect_attribute(self, func, msg):
|
||||
mapped = {k: func(k, v) for k, v in asdict(self).items() if v is not None}
|
||||
s = set(mapped.values())
|
||||
if len(s) <= 0:
|
||||
return None
|
||||
if len(s) != 1:
|
||||
raise ValueError(f"Either no tracks or inconsistent {msg}: {mapped}")
|
||||
return next(iter(s))
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.sequence is not None:
|
||||
return self.sequence.size(0)
|
||||
elif self.structure is not None:
|
||||
return self.structure.size(0)
|
||||
elif self.secondary_structure is not None:
|
||||
return self.secondary_structure.size(0)
|
||||
elif self.sasa is not None:
|
||||
return self.sasa.size(0)
|
||||
elif self.coordinates is not None:
|
||||
return self.coordinates.size(0)
|
||||
else:
|
||||
raise ValueError("No track to determine length from.")
|
||||
l = self._detect_attribute(lambda _, x: x.size(0), "length")
|
||||
return l if l is not None else 0
|
||||
|
||||
@property
|
||||
def device(self) -> str | torch.device:
|
||||
device_ = None
|
||||
|
||||
tracks = [f.name for f in attr.fields(ESMProteinTensor)]
|
||||
|
||||
for track in tracks:
|
||||
current_track: torch.Tensor | None = getattr(self, track)
|
||||
if current_track is not None:
|
||||
if device_ is not None and device_ != current_track.device:
|
||||
raise ValueError(f"Inconsistent devices for track {track}.")
|
||||
device_ = getattr(self, track).device
|
||||
|
||||
if device_ is None:
|
||||
raise ValueError("No track to determine device from.")
|
||||
|
||||
return device_
|
||||
|
||||
def to(self, device: str | torch.device | None) -> ESMProteinTensor:
|
||||
if device is None:
|
||||
return self
|
||||
|
||||
device = torch.device(device)
|
||||
d = self._detect_attribute(lambda _, x: x.device, "device")
|
||||
assert d is not None
|
||||
return d
|
||||
|
||||
def to(self, device_or_dtype: str | torch.device | torch.dtype) -> ESMProteinTensor:
|
||||
def _to(name):
|
||||
v = getattr(self, name)
|
||||
if v is not None:
|
||||
setattr(self, name, v.to(device))
|
||||
setattr(self, name, v.to(device_or_dtype))
|
||||
|
||||
for n in [
|
||||
"sequence",
|
||||
"structure",
|
||||
"secondary_structure",
|
||||
"sasa",
|
||||
"function",
|
||||
"residue_annotations",
|
||||
"coordinates",
|
||||
]:
|
||||
_to(n)
|
||||
for n in attr.fields(ESMProteinTensor):
|
||||
_to(n.name)
|
||||
|
||||
return self
|
||||
|
||||
@@ -202,6 +182,11 @@ class ESMProteinTensor:
|
||||
)
|
||||
|
||||
|
||||
@define
|
||||
class ESMProteinError(Exception, ProteinType):
|
||||
error_msg: str
|
||||
|
||||
|
||||
## High Level Endpoint Types
|
||||
@define
|
||||
class GenerationConfig:
|
||||
@@ -285,14 +270,10 @@ class ForwardAndSampleOutput(ForwardOutput):
|
||||
topk_logprob: ForwardTrackData | None = None
|
||||
# Which tokens correspond to top probability
|
||||
topk_tokens: ForwardTrackData | None = None
|
||||
|
||||
per_residue_embedding: torch.Tensor | None = None
|
||||
mean_embedding: torch.Tensor | None = None
|
||||
|
||||
|
||||
ProteinType = TypeVar("ProteinType", bound=ESMProteinTensor | ESMProtein)
|
||||
|
||||
|
||||
class ESM3InferenceClient(ABC):
|
||||
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
|
||||
# This is the easiest and most flexible way to run ESM3. Generate will
|
||||
@@ -302,6 +283,14 @@ class ESM3InferenceClient(ABC):
|
||||
# if a ESMProteinTensor is provided, encode and decode are skipped
|
||||
raise NotImplementedError
|
||||
|
||||
def batch_generate(
    self, inputs: Sequence[ProteinType], configs: Sequence[GenerationConfig]
) -> Sequence[ProteinType]:
    """Batched counterpart of generate(...): generate several proteins at once.

    The base implementation always raises; concrete clients provide the
    actual behavior.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()
|
||||
|
||||
def encode(self, input: ESMProtein) -> ESMProteinTensor:
|
||||
# Encode allows for encoding RawRepresentation into TokenizedRepresentation.
|
||||
# This runs the structure_token_encoder, as well as dealing with PDB => atom37 conversion
|
||||
|
||||
336
esm/sdk/forge.py
Normal file
336
esm/sdk/forge.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import asyncio
|
||||
from typing import Sequence
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
ForwardAndSampleOutput,
|
||||
ForwardTrackData,
|
||||
GenerationConfig,
|
||||
ProteinType,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.utils.misc import maybe_list, maybe_tensor
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
    """Turn raw annotation tuples into ``FunctionAnnotation`` objects.

    Args:
        l: Collection of argument tuples, each unpacked positionally into
            ``FunctionAnnotation``; may be None.

    Returns:
        One ``FunctionAnnotation`` per entry of ``l``, or None when ``l`` is
        None or empty.
    """
    if l is None:
        return None
    if len(l) <= 0:
        return None
    return [FunctionAnnotation(*fields) for fields in l]
|
||||
|
||||
|
||||
class ESM3ForgeInferenceClient(ESM3InferenceClient):
|
||||
def __init__(self, model, url, token):
    """Remember how to reach a Forge-hosted model.

    Args:
        model: Model identifier sent as the "model" field of each request.
        url: Base URL of the service; endpoint paths are appended to it.
        token: API token, sent as a Bearer credential in the request headers.
    """
    self.model, self.url, self.token = model, url, token
    # Built once so every request reuses the same authorization header.
    self.headers = {"Authorization": f"Bearer {self.token}"}
|
||||
|
||||
def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
    """Generate the track selected by ``config.track`` through the Forge API.

    Args:
        input: Either an ``ESMProtein`` or an ``ESMProteinTensor``.
        config: Generation settings forwarded to the service.

    Returns:
        The generated protein, of the same kind as ``input``. An
        ``ESMProteinError`` is returned instead when the request fails or
        when ``input`` is of an unsupported type.
    """
    if isinstance(input, ESMProtein):
        output = self.__generate_protein(input, config)
    elif isinstance(input, ESMProteinTensor):
        output = self.__generate_protein_tensor(input, config)
    else:
        return ESMProteinError(error_msg=f"Unknown input type {type(input)}")

    if (
        isinstance(output, ESMProtein)
        and isinstance(input, ESMProtein)
        and config.track not in ["function", "residue_annotations"]
    ):
        # Function and residue annotation encoding/decoding is lossy.
        # There is no guarantee that decoding encoded tokens will yield the
        # same input, so the caller's annotations are carried over unchanged
        # unless one of those tracks is the one being generated.
        output.function_annotations = input.function_annotations

    return output
|
||||
|
||||
def batch_generate(
    self, inputs: list[ProteinType], configs: list[GenerationConfig]
) -> Sequence[ProteinType]:
    """Run ``generate()`` for every (input, config) pair concurrently.

    Forge supports auto-batching, so batching on the client side is as simple
    as issuing the individual ``generate()`` requests in parallel.

    Args:
        inputs: Proteins to generate from.
        configs: Generation configs, paired with ``inputs`` by position.

    Returns:
        One result per pair, in input order. A call that raised is reported
        as an ``ESMProteinError`` carrying the exception message rather than
        aborting the whole batch.
    """
    # Imported locally so the method does not rely on the module's
    # top-level import block.
    from concurrent.futures import ThreadPoolExecutor

    pairs = list(zip(inputs, configs))
    if not pairs:
        return []

    # A plain thread pool replaces
    # asyncio.get_event_loop().run_until_complete(...): that call raises
    # "This event loop is already running" when invoked from a notebook or any
    # other async context, and depends on an implicitly created event loop.
    # Leaving the ``with`` block waits for every request to finish.
    with ThreadPoolExecutor() as pool:
        futures = [
            pool.submit(self.generate, protein, config) for protein, config in pairs
        ]

    results = []
    for future in futures:
        try:
            results.append(future.result())
        except Exception as e:
            results.append(ESMProteinError(str(e)))
    return results
|
||||
|
||||
def __generate_protein(
    self, input: ESMProtein, config: GenerationConfig
) -> ESMProtein | ESMProteinError:
    """Call the "generate" endpoint with an untokenized protein.

    Args:
        input: Protein whose tracks are sent as the request inputs.
        config: Generation settings copied field by field into the request.

    Returns:
        The protein rebuilt from the response "outputs", or an
        ``ESMProteinError`` holding the message when the request raises
        ``RuntimeError``.
    """
    # NOTE(review): encode() sends input.sasa as-is, while this goes through
    # maybe_list — confirm which form the service expects.
    inputs = {
        "sequence": input.sequence,
        "secondary_structure": input.secondary_structure,
        "sasa": maybe_list(input.sasa),
    }
    annotations = input.function_annotations
    if annotations is not None:
        inputs["function"] = [a.to_tuple() for a in annotations]
    inputs["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

    payload = {
        "model": self.model,
        "inputs": inputs,
        "track": config.track,
        "invalid_ids": config.invalid_ids,
        "schedule": config.schedule,
        "num_steps": config.num_steps,
        "temperature": config.temperature,
        "top_p": config.top_p,
        "condition_on_coordinates_only": config.condition_on_coordinates_only,
    }

    try:
        data = self.__post("generate", payload)
    except RuntimeError as e:
        return ESMProteinError(error_msg=str(e))

    outputs = data["outputs"]
    return ESMProtein(
        sequence=outputs["sequence"],
        secondary_structure=outputs["secondary_structure"],
        sasa=outputs["sasa"],
        function_annotations=_list_to_function_annotations(outputs["function"]),
        coordinates=maybe_tensor(outputs["coordinates"], convert_none_to_nan=True),
        plddt=maybe_tensor(outputs["plddt"]),
        ptm=maybe_tensor(outputs["ptm"]),
    )
|
||||
|
||||
def __generate_protein_tensor(
    self, input: ESMProteinTensor, config: GenerationConfig
) -> ESMProteinTensor | ESMProteinError:
    """Call the "generate_tensor" endpoint with a tokenized protein.

    Args:
        input: Tokenized protein; each track is converted with ``maybe_list``
            before being sent.
        config: Generation settings copied field by field into the request.

    Returns:
        The tokenized protein rebuilt from the response "outputs" (a track
        absent from the response becomes None), or an ``ESMProteinError``
        holding the message when the request raises ``RuntimeError``.
    """
    # Note the wire name "residue_annotation" (singular) versus the
    # attribute name ``residue_annotations``.
    inputs = {
        "sequence": maybe_list(input.sequence),
        "structure": maybe_list(input.structure),
        "secondary_structure": maybe_list(input.secondary_structure),
        "sasa": maybe_list(input.sasa),
        "function": maybe_list(input.function),
        "coordinates": maybe_list(input.coordinates, convert_nan_to_none=True),
        "residue_annotation": maybe_list(input.residue_annotations),
    }

    payload = {
        "model": self.model,
        "inputs": inputs,
        "track": config.track,
        "invalid_ids": config.invalid_ids,
        "schedule": config.schedule,
        "num_steps": config.num_steps,
        "temperature": config.temperature,
        "top_p": config.top_p,
        "condition_on_coordinates_only": config.condition_on_coordinates_only,
    }

    try:
        data = self.__post("generate_tensor", payload)
    except RuntimeError as e:
        return ESMProteinError(error_msg=str(e))

    outputs = data["outputs"]

    def _tensor_or_none(key, convert_none_to_nan: bool = False):
        # Tracks the service did not return stay unset.
        if key in outputs:
            return maybe_tensor(outputs[key], convert_none_to_nan=convert_none_to_nan)
        return None

    return ESMProteinTensor(
        sequence=_tensor_or_none("sequence"),
        structure=_tensor_or_none("structure"),
        secondary_structure=_tensor_or_none("secondary_structure"),
        sasa=_tensor_or_none("sasa"),
        function=_tensor_or_none("function"),
        residue_annotations=_tensor_or_none("residue_annotation"),
        coordinates=_tensor_or_none("coordinates", convert_none_to_nan=True),
    )
|
||||
|
||||
def forward_and_sample(
    self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
) -> ForwardAndSampleOutput:
    """Run one forward pass plus sampling on the Forge service.

    Args:
        input: Tokenized protein whose tracks are sent as the request inputs.
        sampling_configuration: Per-track sampling settings; a track whose
            setting is missing or None is sent as None.

    Returns:
        The sampled tokens together with per-track log-probabilities,
        probabilities (exp of the log-probabilities), entropy and top-k data
        rebuilt from the response.

    Raises:
        RuntimeError: Propagated from the request helper when the HTTP
            response is not OK.
    """
    # Tracks that are configured for sampling and read back from the response.
    track_names = ("sequence", "structure", "secondary_structure", "sasa", "function")

    inputs = {
        "sequence": maybe_list(input.sequence),
        "structure": maybe_list(input.structure),
        "secondary_structure": maybe_list(input.secondary_structure),
        "sasa": maybe_list(input.sasa),
        "function": maybe_list(input.function),
        "coordinates": maybe_list(input.coordinates, convert_nan_to_none=True),
        "residue_annotation": maybe_list(input.residue_annotations),
    }

    per_track_config = {}
    for name in track_names:
        track_config = getattr(sampling_configuration, name, None)
        if track_config is None:
            per_track_config[name] = None
            continue
        per_track_config[name] = {
            "temperature": track_config.temperature,
            "top_p": track_config.top_p,
            "only_sample_masked_tokens": track_config.only_sample_masked_tokens,
            "invalid_ids": track_config.invalid_ids,
            "topk_logprobs": track_config.topk_logprobs,
        }

    payload = {
        "model": self.model,
        "inputs": inputs,
        "sampling_config": per_track_config,
        "embedding_config": None,  # TODO(zeming)
    }
    data = self.__post("forward_and_sample", payload)

    def _read(track, field):
        # Both the whole track entry and the individual field may be None.
        entry = data[track]
        if entry is None:
            return None
        value = entry[field]
        if value is None:
            return None
        return torch.tensor(value)

    def _read_all_tracks(field):
        return ForwardTrackData(**{name: _read(name, field) for name in track_names})

    def _map_tracks(track_data, fn):
        # Apply fn to every populated track, leaving None entries untouched.
        mapped = {}
        for name in track_names:
            value = getattr(track_data, name)
            mapped[name] = None if value is None else fn(value)
        return ForwardTrackData(**mapped)

    tokens = ESMProteinTensor(**{name: _read(name, "tokens") for name in track_names})
    logprob = _read_all_tracks("logprobs")

    return ForwardAndSampleOutput(
        protein_tensor=tokens,
        logprob=logprob,
        prob=_map_tracks(logprob, torch.exp),
        entropy=_read_all_tracks("entropy"),
        topk_logprob=_read_all_tracks("topk_logprobs"),
        topk_tokens=_read_all_tracks("topk_tokens"),
    )
|
||||
|
||||
def encode(self, input: ESMProtein) -> ESMProteinTensor:
    """Tokenize a protein through the "encode" endpoint.

    Args:
        input: Untokenized protein; its function annotations are sent as
            tuples and only when they are set.

    Returns:
        The tokenized protein rebuilt from the response "outputs".

    Raises:
        RuntimeError: Propagated from the request helper when the HTTP
            response is not OK.
    """
    inputs = {
        "sequence": input.sequence,
        "secondary_structure": input.secondary_structure,
        "sasa": input.sasa,
    }
    annotations = input.function_annotations
    if annotations is not None:
        inputs["function"] = [a.to_tuple() for a in annotations]
    inputs["coordinates"] = maybe_list(input.coordinates, convert_nan_to_none=True)

    data = self.__post("encode", {"inputs": inputs, "model": self.model})
    outputs = data["outputs"]

    return ESMProteinTensor(
        sequence=maybe_tensor(outputs["sequence"]),
        structure=maybe_tensor(outputs["structure"]),
        coordinates=maybe_tensor(outputs["coordinates"], convert_none_to_nan=True),
        secondary_structure=maybe_tensor(outputs["secondary_structure"]),
        sasa=maybe_tensor(outputs["sasa"]),
        function=maybe_tensor(outputs["function"]),
        residue_annotations=maybe_tensor(outputs["residue_annotation"]),
    )
|
||||
|
||||
def decode(self, input: ESMProteinTensor) -> ESMProtein:
    """Convert a tokenized protein back through the "decode" endpoint.

    Args:
        input: Tokenized protein; each track is converted with ``maybe_list``
            before being sent.

    Returns:
        The protein rebuilt from the response "outputs".

    Raises:
        RuntimeError: Propagated from the request helper when the HTTP
            response is not OK.
    """
    inputs = {
        "sequence": maybe_list(input.sequence),
        "structure": maybe_list(input.structure),
        "secondary_structure": maybe_list(input.secondary_structure),
        "sasa": maybe_list(input.sasa),
        "function": maybe_list(input.function),
        "residue_annotation": maybe_list(input.residue_annotations),
        "coordinates": maybe_list(input.coordinates, convert_nan_to_none=True),
    }

    data = self.__post("decode", {"model": self.model, "inputs": inputs})
    outputs = data["outputs"]

    return ESMProtein(
        sequence=outputs["sequence"],
        secondary_structure=outputs["secondary_structure"],
        sasa=outputs["sasa"],
        function_annotations=_list_to_function_annotations(outputs["function"]),
        coordinates=maybe_tensor(outputs["coordinates"], convert_none_to_nan=True),
        plddt=maybe_tensor(outputs["plddt"]),
        ptm=maybe_tensor(outputs["ptm"]),
    )
|
||||
|
||||
def __post(self, endpoint, request):
    """POST a JSON request to ``{url}/api/v1/{endpoint}`` and decode the reply.

    Args:
        endpoint: Endpoint name appended to the API base path.
        request: JSON-serializable request body.

    Returns:
        The decoded JSON response; when it has no "outputs" key but has a
        "data" key, the value under "data" is returned instead.

    Raises:
        RuntimeError: If the HTTP response is not OK.
    """
    # NOTE(review): no timeout is passed, so a stalled connection blocks
    # indefinitely — consider whether one should be configured.
    response = requests.post(
        f"{self.url}/api/v1/{endpoint}", json=request, headers=self.headers
    )
    if not response.ok:
        raise RuntimeError(f"Failure in {endpoint}: {response.text}")

    data = response.json()
    # Nextjs puts outputs dict under "data" key.
    # Lift it up for easier downstream processing.
    wrapped = "outputs" not in data and "data" in data
    return data["data"] if wrapped else data
|
||||
@@ -207,7 +207,7 @@ class InterProQuantizedTokenizer(EsmTokenizerBase):
|
||||
interpro_ids = []
|
||||
keywords = []
|
||||
for label in labels:
|
||||
match = re.match(r"IPR\d+", label)
|
||||
match = re.search(r"IPR\d+", label)
|
||||
if match and match.group() in self.interpro_to_index:
|
||||
interpro_ids.append(match.group())
|
||||
elif label in self._tfidf.vocab_to_index:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cloudpathlib import AnyPath
|
||||
|
||||
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
||||
from esm.utils.constants import esm3 as C
|
||||
@@ -25,13 +25,13 @@ class ResidueAnnotationsTokenizer(EsmTokenizerBase):
|
||||
|
||||
@cached_property
|
||||
def _description2label(self) -> dict[str, str]:
|
||||
with Path(self.csv_path).open() as f: # type: ignore
|
||||
with AnyPath(self.csv_path).open() as f: # type: ignore
|
||||
df = pd.read_csv(f)
|
||||
return dict(zip(df.label, df.label_clean))
|
||||
|
||||
@cached_property
|
||||
def _labels(self) -> list[str]:
|
||||
with Path(self.csv_path).open() as f: # type: ignore
|
||||
with AnyPath(self.csv_path).open() as f: # type: ignore
|
||||
df = pd.read_csv(f)
|
||||
labels = (
|
||||
df.groupby("label_clean")["count"]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
@@ -106,6 +107,8 @@ def data_root():
|
||||
]:
|
||||
if (p := Path(path)).exists():
|
||||
return p.parent
|
||||
if "INFRA_PROVIDER" in os.environ:
|
||||
return Path("")
|
||||
# Try to download from Hugging Face if it doesn't exist
|
||||
path = Path(snapshot_download(repo_id="EvolutionaryScale/esm3-sm-open-v1"))
|
||||
return path
|
||||
|
||||
@@ -189,10 +189,13 @@ def decode_sasa(
|
||||
sasa_tokens: torch.Tensor,
|
||||
sasa_tokenizer: SASADiscretizingTokenizer,
|
||||
) -> list[float]:
|
||||
_bos_eos_warn("SASA", sasa_tokens, sasa_tokenizer)
|
||||
if sasa_tokens[0] != 0:
|
||||
raise ValueError("SASA does not start with 0 corresponding to BOS token")
|
||||
if sasa_tokens[-1] != 0:
|
||||
raise ValueError("SASA does not end with 0 corresponding to EOS token")
|
||||
sasa_tokens = sasa_tokens[1:-1]
|
||||
|
||||
return sasa_tokenizer.decode_float(sasa_tokens)
|
||||
sasa = sasa_tokens.tolist()
|
||||
return sasa
|
||||
|
||||
|
||||
def decode_function_annotations(
|
||||
|
||||
@@ -39,7 +39,7 @@ def encode_function_annotations(
|
||||
supported_label = False
|
||||
|
||||
# Is it an InterPro label?
|
||||
if match := re.match(r"IPR\d+", fa.label):
|
||||
if match := re.search(r"IPR\d+", fa.label):
|
||||
if match.group() in function_tokens_tokenizer.interpro_to_index:
|
||||
ft_annotations.append(fa)
|
||||
supported_label = True
|
||||
|
||||
@@ -5,11 +5,11 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cloudpathlib import AnyPath
|
||||
|
||||
from esm.utils.constants import esm3 as C
|
||||
|
||||
@@ -36,7 +36,7 @@ def _parse_interpro2go(path: str) -> dict[str, list[str]]:
|
||||
Returns:
|
||||
Mapping from InterPro to list of associated GO terms.
|
||||
"""
|
||||
with Path(path).open("r") as f:
|
||||
with AnyPath(path).open("r") as f:
|
||||
text = f.read()
|
||||
df = pd.Series(text.split("\n"), name="line").to_frame()
|
||||
df = df[~df.line.str.startswith("!")]
|
||||
@@ -131,7 +131,7 @@ class InterPro:
|
||||
- "type": InterProEntryType representing the type of annotation.
|
||||
- "name": Short name of the entry.
|
||||
"""
|
||||
with Path(self.entries_path).open("r") as f:
|
||||
with AnyPath(self.entries_path).open("r") as f:
|
||||
df = pd.read_csv(f, sep="\t")
|
||||
assert all(
|
||||
col in df.columns for col in ["ENTRY_AC", "ENTRY_TYPE", "ENTRY_NAME"]
|
||||
@@ -178,7 +178,7 @@ class InterPro:
|
||||
def graph(self) -> nx.DiGraph:
|
||||
"""Reads the InterPro hierarchy of InterPro."""
|
||||
graph = nx.DiGraph()
|
||||
with Path(self.hierarchy_graph_path).open("r") as f:
|
||||
with AnyPath(self.hierarchy_graph_path).open("r") as f:
|
||||
parents = []
|
||||
for line in f:
|
||||
ipr = line.split("::", maxsplit=1)[0]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from cloudpathlib import AnyPath
|
||||
|
||||
from esm.utils.types import PathLike
|
||||
|
||||
@@ -42,7 +41,7 @@ class LSHTokenized:
|
||||
):
|
||||
table_hyperplanes = None
|
||||
if filepath is not None:
|
||||
filepath = Path(filepath)
|
||||
filepath = AnyPath(filepath)
|
||||
if not filepath.exists():
|
||||
raise FileNotFoundError(filepath)
|
||||
table_hyperplanes = np.load(filepath) # type: ignore
|
||||
@@ -83,7 +82,7 @@ class LSHBitstream:
|
||||
):
|
||||
table_hyperplanes = None
|
||||
if filepath is not None:
|
||||
filepath = Path(filepath)
|
||||
filepath = AnyPath(filepath)
|
||||
if not filepath.exists():
|
||||
raise FileNotFoundError(filepath)
|
||||
table_hyperplanes = np.load(filepath)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import Counter
|
||||
from functools import cached_property
|
||||
|
||||
import numpy as np
|
||||
from cloudpathlib import AnyPath
|
||||
from scipy import sparse
|
||||
|
||||
|
||||
@@ -13,10 +14,10 @@ class TFIDFModel:
|
||||
"""
|
||||
|
||||
def __init__(self, vocabulary_path: str, idf_path: str):
|
||||
with open(vocabulary_path, "r") as f:
|
||||
with AnyPath(vocabulary_path).open("r") as f:
|
||||
self.vocabulary = f.read().strip().split("\n")
|
||||
|
||||
with open(idf_path, "rb") as f:
|
||||
with AnyPath(idf_path).open("rb") as f:
|
||||
self.idf_ = np.load(f)
|
||||
|
||||
assert self.idf_.ndim == 1
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Callable
|
||||
import os
|
||||
from typing import Any, Callable, Sequence
|
||||
from warnings import warn
|
||||
|
||||
import attr
|
||||
import torch
|
||||
@@ -7,8 +9,14 @@ from tqdm import tqdm
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
ForwardAndSampleOutput,
|
||||
ForwardConfig,
|
||||
ForwardOutput,
|
||||
ForwardTrackData,
|
||||
GenerationConfig,
|
||||
ReturnLogitsConfig,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
@@ -16,170 +24,694 @@ from esm.tokenization import (
|
||||
EsmTokenizerBase,
|
||||
TokenizerCollectionProtocol,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
||||
from esm.utils.sampling import (
|
||||
_BatchedESMProteinTensor,
|
||||
sample_function_logits,
|
||||
sample_logits,
|
||||
sample_residue_annotation_logits,
|
||||
sample_sasa_logits,
|
||||
)
|
||||
from evolutionaryscale.utils.tensor import (
|
||||
stack_variable_length_tensors,
|
||||
)
|
||||
|
||||
|
||||
def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
    """Trim every tensor of an attrs object along the sequence dimension.

    This util assumes that the tensors carry a leading batch dimension:
    each tensor ``v`` is replaced by ``v[:, :sequence_len]``.

    Args:
        o: attrs-defined object to trim; nested attrs objects are trimmed
            recursively.
        sequence_len: Number of positions to keep on the second dimension.

    Returns:
        A copy of ``o`` with its tensors trimmed; None and non-tensor values
        are carried over unchanged.
    """
    assert attr.has(o.__class__)

    def _trim(value):
        if isinstance(value, torch.Tensor):
            # Trim padding.
            return value[:, :sequence_len]
        if value is not None and attr.has(value.__class__):
            # Recursively trim the child attribute.
            return _trim_sequence_tensor_dataclass(value, sequence_len)
        # None and any other value are copied over as they are.
        return value

    trimmed = {
        name: _trim(value) for name, value in attr.asdict(o, recurse=False).items()
    }
    return attr.evolve(o, **trimmed)
|
||||
|
||||
|
||||
def _slice_tensor_dataclass(o: Any, i: int, keep_dim: bool = False) -> Any:
    """Pick one batch entry out of every tensor of an attrs object.

    Args:
        o: attrs-defined object to slice; nested attrs objects are sliced
            recursively.
        i: Index of the row to select along dimension 0.
        keep_dim: Whether to keep the batch dimension after slicing. For a
            tensor of shape (5, 8) the result has shape (1, 8) when True and
            (8,) when False. The default is False.

    Returns:
        A copy of ``o`` holding the selected rows; None and non-tensor values
        are carried over unchanged.
    """
    assert attr.has(o.__class__)

    picked = {}
    for name, value in attr.asdict(o, recurse=False).items():
        if isinstance(value, torch.Tensor):
            # Select the i-th row, optionally restoring a batch dim of size 1.
            row = value.select(0, i)
            picked[name] = row.unsqueeze(0) if keep_dim else row
        elif value is not None and attr.has(value.__class__):
            # Recursively slice the child attribute.
            picked[name] = _slice_tensor_dataclass(value, i, keep_dim)
        else:
            # None and any other value are copied over as they are.
            picked[name] = value

    return attr.evolve(o, **picked)
|
||||
|
||||
|
||||
def iterative_sampling_raw(
|
||||
client: ESM3InferenceClient,
|
||||
input: ESMProtein,
|
||||
config: GenerationConfig,
|
||||
):
|
||||
proteins: list[ESMProtein],
|
||||
configs: list[GenerationConfig],
|
||||
) -> list[ESMProtein | ESMProteinError]:
|
||||
# Keep structure tokens
|
||||
input_tokens = client.encode(input)
|
||||
input_tokens = [client.encode(protein) for protein in proteins]
|
||||
|
||||
output_tokens = client.generate(input_tokens, config)
|
||||
output_tokens_list = client.batch_generate(input_tokens, configs)
|
||||
|
||||
raw_protein = client.decode(output_tokens)
|
||||
raw_proteins: list[ESMProtein | ESMProteinError] = []
|
||||
for output_tokens in output_tokens_list:
|
||||
if isinstance(output_tokens, ESMProteinTensor):
|
||||
raw_proteins.append(client.decode(output_tokens))
|
||||
elif isinstance(output_tokens, ESMProteinError):
|
||||
raw_proteins.append(output_tokens)
|
||||
else:
|
||||
raise ValueError(f"Unknown output type {type(output_tokens)}")
|
||||
|
||||
for input_protein, raw_protein, config in zip(proteins, raw_proteins, configs):
|
||||
if isinstance(raw_protein, ESMProteinError):
|
||||
# If this generation errored out.
|
||||
continue
|
||||
if config.track not in ["function", "residue_annotations"]:
|
||||
# Function and residue annotation encoding/decoding is lossy
|
||||
# There is no guarantee that decoding encoded tokens will yield the same input
|
||||
raw_protein.function_annotations = input_protein.function_annotations
|
||||
|
||||
return raw_proteins
|
||||
|
||||
|
||||
def _make_masked_inputs(
    track: str,
    sequence_length: int,
    tokenizers: TokenizerCollectionProtocol,
):
    """Build an all-masked placeholder tensor for one track.

    Args:
        track: Track name; also the attribute looked up on ``tokenizers``
            for token tracks.
        sequence_length: Size of the first dimension of the result.
        tokenizers: Tokenizer collection supplying the mask/BOS/EOS ids.

    Returns:
        * "coordinates": float tensor of shape (sequence_length, 3, 3)
          filled with inf.
        * "attention_mask": bool tensor of shape (sequence_length,) filled
          with True.
        * any other track: long tensor filled with the track's mask token id,
          with the first entry set to BOS and the last to EOS. The
          "function" and "residue_annotations" tracks get a second dimension.
    """
    if track == "coordinates":
        return torch.full((sequence_length, 3, 3), torch.inf, dtype=torch.float)
    if track == "attention_mask":
        return torch.full((sequence_length,), 1, dtype=torch.bool)

    if track == "function":
        shape = (sequence_length, tokenizers.function.depth)
    elif track == "residue_annotations":
        shape = (sequence_length, C.MAX_RESIDUE_ANNOTATIONS)
    else:
        shape = (sequence_length,)

    tokenizer = getattr(tokenizers, track)
    placeholder = torch.full(shape, tokenizer.mask_token_id, dtype=torch.long)
    # BOS is written first, so EOS wins if both land on the same position.
    placeholder[0] = tokenizer.bos_token_id
    placeholder[-1] = tokenizer.eos_token_id
    return placeholder
|
||||
|
||||
|
||||
def _stack_protein_tensors(
    input_tokens: list[ESMProteinTensor],
    sequence_lengths: list[int],
    tokenizers: TokenizerCollectionProtocol,
    device: str | torch.device,
) -> _BatchedESMProteinTensor:
    """Stack per-protein tensors into one batched protein tensor.

    Args:
        input_tokens: Tokenized proteins to combine.
        sequence_lengths: Length used for the placeholder of each protein,
            paired with ``input_tokens`` by position.
        tokenizers: Tokenizer collection supplying the pad ids and, through
            ``_make_masked_inputs``, the placeholder ids.
        device: Device the placeholders are moved to.

    Returns:
        A ``_BatchedESMProteinTensor`` with every field of
        ``ESMProteinTensor`` set to the stacked tensors.
    """
    batched = _BatchedESMProteinTensor()

    for field in attr.fields(ESMProteinTensor):
        name = field.name
        raw = [getattr(tokens, name) for tokens in input_tokens]

        # Create all-mask mock inputs for any tensors that are None.
        filled = []
        for tensor, length in zip(raw, sequence_lengths):
            if tensor is None:
                tensor = _make_masked_inputs(name, length, tokenizers).to(device)
            filled.append(tensor)

        # Value used to pad shorter entries: inf for coordinates, the
        # tokenizer's pad token id for every other field.
        if name == "coordinates":
            pad_value = torch.inf
        else:
            pad_value = getattr(tokenizers, name).pad_token_id

        stacked = stack_variable_length_tensors(
            sequences=filled,
            constant_value=pad_value,
        )
        setattr(batched, name, stacked)

    return batched
|
||||
|
||||
|
||||
def _get_masked_positions(
    track: str, tokens: torch.Tensor, mask_token_id: int
) -> torch.Tensor:
    """Return a boolean tensor flagging the masked positions of a track.

    Args:
        track: Track name; only "function" is treated specially.
        tokens: Token tensor to inspect.
        mask_token_id: Id that marks a masked token.

    Returns:
        ``tokens == mask_token_id``. For the "function" track this is reduced
        over the last dimension, so a position counts as masked only when all
        of its entries are masked.
    """
    is_mask = tokens == mask_token_id
    if track != "function":
        return is_mask
    return torch.all(is_mask, dim=-1).to(tokens.device)
|
||||
|
||||
|
||||
def _get_iterative_sampling_mask_for_prompt_and_step(
|
||||
cur_sampled: _BatchedESMProteinTensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
total_to_sample: torch.Tensor,
|
||||
step: int,
|
||||
entropy: ForwardTrackData,
|
||||
config: GenerationConfig,
|
||||
tokenizers: TokenizerCollectionProtocol,
|
||||
) -> torch.Tensor:
|
||||
"""Get sampling mask based on forward output and config.
|
||||
|
||||
Returns:
|
||||
Boolean mask marking the positions to sample at this step.
|
||||
"""
|
||||
track_to_sample = config.track
|
||||
tokens = getattr(cur_sampled, track_to_sample)
|
||||
device = tokens.device
|
||||
|
||||
if track_to_sample not in ["function", "residue_annotations"]:
|
||||
# Function and residue annotation encoding/decoding is lossy
|
||||
# There is no guarantee that decoding encoded tokens will yield the same input
|
||||
raw_protein.function_annotations = input.function_annotations
|
||||
shape = tokens.shape
|
||||
B, L = shape[0], shape[1]
|
||||
|
||||
return raw_protein
|
||||
sampling_mask = torch.ones((B, L), dtype=torch.bool, device=device)
|
||||
sampling_mask[:, 0] = False # BOS
|
||||
# EOS and all padding tokens.
|
||||
sampling_mask &= (
|
||||
torch.arange(L).repeat(B, 1) < (sequence_lengths - 1).unsqueeze(-1)
|
||||
).to(device)
|
||||
|
||||
is_mask = _get_masked_positions(
|
||||
track_to_sample,
|
||||
tokens,
|
||||
getattr(tokenizers, track_to_sample).mask_token_id,
|
||||
)
|
||||
if not is_mask.any().item():
|
||||
raise ValueError(f"Cannot sample {config.track} when input has no masks.")
|
||||
sampling_mask = sampling_mask & is_mask
|
||||
|
||||
# Initialize schedule and masks
|
||||
decoding_schedule = NOISE_SCHEDULE_REGISTRY[config.schedule]
|
||||
|
||||
# Calculate number of tokens to sample
|
||||
still_masked = torch.sum(sampling_mask).int()
|
||||
perc_masked_after_this_step = decoding_schedule(
|
||||
torch.tensor((step + 1) / config.num_steps)
|
||||
)
|
||||
num_tokens_masked_after_this_step = (
|
||||
perc_masked_after_this_step * total_to_sample
|
||||
).int()
|
||||
num_to_sample = still_masked - num_tokens_masked_after_this_step
|
||||
|
||||
track_entropy: torch.Tensor = getattr(
|
||||
entropy, track_to_sample
|
||||
) # (B, L) or (B, L, D)
|
||||
|
||||
if track_to_sample == "function":
|
||||
track_entropy = track_entropy.sum(-1) # (B, L, D) -> (B, L)
|
||||
|
||||
track_entropy = track_entropy.masked_fill(
|
||||
~sampling_mask, torch.finfo(track_entropy.dtype).max
|
||||
)
|
||||
_, indices = track_entropy.topk(num_to_sample, dim=-1, largest=False)
|
||||
is_top_k = torch.zeros((B, L), dtype=torch.bool, device=device).scatter(
|
||||
1, indices, True
|
||||
)
|
||||
where_to_sample = sampling_mask & is_top_k
|
||||
|
||||
if track_to_sample == "function":
|
||||
where_to_sample = where_to_sample.unsqueeze(-1).expand(
|
||||
B,
|
||||
L,
|
||||
tokenizers.function.depth,
|
||||
) # (B, L) -> (B, L, D)
|
||||
|
||||
return where_to_sample
|
||||
|
||||
|
||||
def iterative_sampling_tokens(
    client: ESM3InferenceClient,
    input_tokens: list[ESMProteinTensor],
    configs: list[GenerationConfig],
    tokenizers: TokenizerCollectionProtocol,
) -> Sequence[ESMProteinTensor | ESMProteinError]:
    """Iteratively unmask one track per prompt for a batch of prompts.

    Each prompt is paired with its own ``GenerationConfig``. All prompts are
    stacked into a single batch for the forward pass, then sampled one prompt
    at a time because their sampling configurations may differ.

    Args:
        client: Inference client; only its ``_forward`` is used (through
            ``_batch_forward``).
        input_tokens: One tokenized prompt per batch row. They must all live
            on the same device.
        configs: One generation config per prompt, aligned with
            ``input_tokens``.
        tokenizers: Tokenizer collection used for mask token ids and stacking.

    Returns:
        A list aligned with ``input_tokens``. Each entry is either the sampled
        ``ESMProteinTensor`` (only ``config.track`` differs from the input;
        every other track, including coordinates, is restored from the input)
        or an ``ESMProteinError`` describing why that prompt failed.

    Raises:
        AttributeError: If the input tensors are spread over several devices.
    """
    devices = set([t.device for t in input_tokens])
    if len(devices) > 1:
        raise AttributeError(f"Input tokens on multiple devices {devices}")

    # Shallow copies, so that clearing a track below does not touch the inputs.
    sampled_tokens = [attr.evolve(tokens) for tokens in input_tokens]

    # Clear structure tokens if user would like to condition only on coordinates.
    for tokens, config in zip(sampled_tokens, configs):
        if config.condition_on_coordinates_only and tokens.coordinates is not None:
            tokens.structure = None

    # Total sequence lengths.
    sequence_lengths = [len(tokens) for tokens in sampled_tokens]
    # Figure out the number of tokens to be sampled for each prompt.
    # Use the (possibly structure-cleared) copies so that the count agrees
    # with what is actually fed to the model.
    total_to_sample = []
    for protein, seq_len, config in zip(sampled_tokens, sequence_lengths, configs):
        track = config.track

        if getattr(protein, track) is None:
            # We need to sample the entire track (everything but BOS / EOS).
            total_to_sample.append(seq_len - 2)
            continue

        masked = _get_masked_positions(
            track,
            getattr(protein, track),
            getattr(tokenizers, track).mask_token_id,
        )
        total_to_sample.append(torch.sum(masked))

    # Different prompts may ask for different number of decoding steps.
    # For now, we simply run the max number of steps.
    # TODO: return completed proteins as soon as they are finished sampling.
    max_num_steps = max([config.num_steps for config in configs])

    # Now stack the list to make a single batched ESMProteinTensor.
    # The copies are stacked (not the raw inputs); otherwise the structure
    # clearing requested by condition_on_coordinates_only would be ignored.
    batched_tokens = _stack_protein_tensors(
        sampled_tokens,
        sequence_lengths,
        tokenizers,
        devices.pop(),
    )

    # Remember sampled prompts that have somehow errored out.
    errors: dict[int, ESMProteinError] = {}

    # Decode
    # Any non-empty value of the environment variable disables the progress bar.
    disable_tqdm = bool(os.environ.get("DISABLE_ITERATIVE_SAMPLING_TQDM", False))
    for t in tqdm(range(max_num_steps), disable=disable_tqdm):
        forward_out = _batch_forward(client, batched_tokens)

        # Sample each prompt individually, since their configuration may
        # be very different.
        # TODO: downstream utils work with batch dimension.
        # Group by sampling configurations and sample those prompts together.
        for i, config in enumerate(configs):  # B
            if i in errors:
                # This prompt has errored out in previous steps.
                # Skip.
                continue

            if config.track in ["coordinates", "residue_annotations"]:
                errors[i] = ESMProteinError(
                    error_msg=f"Iterative sampling {config.track} is not supported."
                )
                continue

            if t >= config.num_steps:
                # Done sampling for this row.
                continue

            per_prompt_cur_sampled = _BatchedESMProteinTensor.from_protein_tensor(
                batched_tokens.slice(i)
            )
            per_prompt_forward_out: ForwardOutput = _slice_tensor_dataclass(
                forward_out, i, keep_dim=True
            )
            # Trim logits to proper sequence length for this prompt.
            per_prompt_forward_out.logits = _trim_sequence_tensor_dataclass(
                per_prompt_forward_out.logits,
                # Note(jungong) : we can not simply use sequence_lengths[i] here,
                # what we want is for the sequence length of the logits to match
                # that of the prompt, which may or may not be padded, depending on
                # whether the padding was done locally with the open source model
                # (where per_prompt_cur_sampled is already padded) or by
                # BatchedForwardRunner (where per_prompt_cur_sampled is not padded).
                len(per_prompt_cur_sampled),
            )

            track_sample_config = SamplingTrackConfig()
            track_sample_config.invalid_ids = config.invalid_ids
            track_sample_config.temperature = config.temperature
            track_sample_config.top_p = config.top_p
            sampling_config = SamplingConfig(**{config.track: track_sample_config})  # type: ignore

            # Sampling has to be done per-prompt, since sampling configs
            # are likely to be different for different prompts.
            per_prompt_forward_and_sample_output = _sample_per_prompt(
                per_prompt_cur_sampled,
                per_prompt_forward_out,
                sampling_config,
                tokenizers,
            )

            # All positions sampled after _sample_per_prompt() above.
            # (B, L) & (B, L, D)
            per_prompt_new_sampled = per_prompt_forward_and_sample_output.protein_tensor

            # Find the positions we should sample this round.
            assert per_prompt_forward_and_sample_output.entropy is not None
            try:
                where_to_sample = _get_iterative_sampling_mask_for_prompt_and_step(
                    per_prompt_cur_sampled,
                    torch.tensor(sequence_lengths[i]),
                    # May already be a 0-dim tensor (torch.sum above);
                    # as_tensor avoids torch.tensor's copy-construct warning.
                    torch.as_tensor(total_to_sample[i]),
                    t,
                    per_prompt_forward_and_sample_output.entropy,
                    config,
                    tokenizers,
                )
            except ValueError as e:
                errors[i] = ESMProteinError(error_msg=str(e))
                continue

            # Tensor.to() returns a new tensor; the result must be kept.
            where_to_sample = where_to_sample.to(input_tokens[0].device)

            old_track_samples = getattr(per_prompt_cur_sampled, config.track)
            new_track_samples = getattr(per_prompt_new_sampled, config.track)

            # Iterative sampling by picking the tokens sampled this round
            # from new_track_samples to old_track_samples.
            new_track_samples = torch.where(
                where_to_sample, new_track_samples, old_track_samples
            )

            # Update the corresponding row with new data.
            getattr(batched_tokens, config.track)[i, ...] = new_track_samples[0]

    # Un-pack to a list of single ProteinTypes.
    output_tokens = [
        batched_tokens.slice(i, sequence_len=sequence_lengths[i])
        if i not in errors
        else errors[i]
        for i in range(len(input_tokens))
    ]

    # Do not update tracks that were not sampled (e.g. keep None instead of masks)
    for inputs, outputs, config in zip(input_tokens, output_tokens, configs):
        if isinstance(outputs, ESMProteinError):
            continue

        # First restore coordinates field.
        # We know coordinates can never be iteratively sampled.
        setattr(outputs, "coordinates", getattr(inputs, "coordinates"))
        # Maybe restore all the other fields.
        for f in attr.fields(SamplingConfig):
            if "embedding" in f.name:
                continue
            if f.name != config.track:
                setattr(outputs, f.name, getattr(inputs, f.name))

    return output_tokens
|
||||
|
||||
|
||||
def _batch_forward(
    client: ESM3InferenceClient,
    protein: _BatchedESMProteinTensor,
):
    """Run one forward pass of ``client`` on a batched protein tensor.

    Logits are requested for every track (sequence, structure, secondary
    structure, SASA, function and residue annotations), together with
    embeddings. Whatever ``client._forward`` returns is passed straight back.
    """
    all_track_logits = ReturnLogitsConfig(
        sequence=True,
        structure=True,
        secondary_structure=True,
        sasa=True,
        function=True,
        residue_annotations=True,
    )
    forward_config = ForwardConfig(all_track_logits, return_embeddings=True)
    return client._forward(protein, forward_config)
|
||||
|
||||
|
||||
def _sample_per_prompt(
    protein: _BatchedESMProteinTensor,
    forward_output: ForwardOutput,
    sampling_config: SamplingConfig,
    tokenizers: TokenizerCollectionProtocol,
) -> ForwardAndSampleOutput:
    """Sample every configured track of a single prompt from forward-pass logits.

    For each track, a ``None`` entry in ``sampling_config`` means "do not
    sample": the prompt's current tokens are cloned into the output unchanged.
    Otherwise new tokens are drawn from the logits in ``forward_output`` and
    the per-track sampling metadata (entropy, probabilities, top-k, ...) is
    collected.

    Args:
        protein: The prompt, with a leading batch dimension.
        forward_output: Forward-pass output holding the logits (and
            embeddings) for this prompt.
        sampling_config: Per-track sampling configuration plus the embedding
            return flags.
        tokenizers: Tokenizer collection, used for mask token ids and the
            function tokenizer.

    Returns:
        A ``ForwardAndSampleOutput`` with the sampled ``protein_tensor``, one
        ``ForwardTrackData`` (or ``None`` if no track produced it) per
        metadata property, and the optionally requested embeddings.
    """

    def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None:
        return x.clone() if x is not None else None

    # Sampling
    tokens_dir = {}
    track_sampling_metadata_dir: dict[str, dict | None] = {}
    for track in ["sequence", "structure", "secondary_structure"]:
        track_config = getattr(sampling_config, track)
        if track_config is None:
            tokens_dir[track] = maybe_clone(getattr(protein, track))
            continue

        sampling_metadata = _sample_track(
            logits=getattr(forward_output.logits, track),
            tokens=getattr(protein, track),
            sampling_track_config=track_config,
            mask_idx=getattr(tokenizers, track).mask_token_id,
        )
        tokens_dir[track] = sampling_metadata.pop("sampled_tokens")
        track_sampling_metadata_dir[track] = sampling_metadata

    # Sample SASA separately
    sasa_config = getattr(sampling_config, "sasa")
    track_sampling_metadata_dir["sasa"] = None

    if sasa_config is None:
        # Keep the prompt's SASA values, consistent with every other track
        # that is not being sampled (previously the track was dropped).
        tokens_dir["sasa"] = maybe_clone(getattr(protein, "sasa"))
    else:
        if sasa_config.topk_logprobs > 0:
            warn("For SASA sampling, 'topk_logprobs' is expected to be 0.")
        # Only the first batch row is used here.
        sasa_logits = forward_output.logits.sasa[0, ...]  # type: ignore
        sasa_value = sample_sasa_logits(sasa_logits, protein.sasa[0, ...])  # type: ignore
        tokens_dir["sasa"] = sasa_value

        probs = sasa_logits.softmax(dim=-1)
        entropy = -(probs * sasa_logits.log_softmax(-1)).sum(-1)

        track_sampling_metadata_dir["sasa"] = {"entropy": entropy}

    # Sample function and residue annotations separately
    function_config = getattr(sampling_config, "function")
    if function_config is None:
        tokens_dir["function"] = maybe_clone(getattr(protein, "function"))
        tokens_dir["residue_annotations"] = maybe_clone(
            getattr(protein, "residue_annotations")
        )
    else:
        sampling_metadata = _sample_function_track(
            tokenizers.function,
            tokens=getattr(protein, "function"),
            logits=getattr(forward_output.logits, "function"),
            sampling_track_config=function_config,
        )
        tokens_dir["function"] = sampling_metadata.pop("sampled_tokens")
        track_sampling_metadata_dir["function"] = sampling_metadata

        # Residue annotations are sampled whenever the function track is.
        sampled_tokens, _ = sample_residue_annotation_logits(
            logits=forward_output.residue_annotation_logits  # type: ignore
        )
        tokens_dir["residue_annotations"] = sampled_tokens

    # Format output
    forward_and_sample_output_dir = {}
    forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir)
    for property_name in (
        "entropy",
        "prob",
        "logprob",
        "top_prob",
        "topk_logprob",
        "topk_tokens",
    ):
        # Gather this property from every track that actually produced it.
        forward_track_data_dir = {}
        for track, values in track_sampling_metadata_dir.items():
            value = values.get(property_name) if values is not None else None
            if value is not None:
                forward_track_data_dir[track] = value
        forward_and_sample_output_dir[property_name] = (
            ForwardTrackData(**forward_track_data_dir)
            if forward_track_data_dir
            else None
        )

    per_res_embed = (
        forward_output.embeddings  # type: ignore
        if sampling_config.return_per_residue_embeddings
        else None
    )
    # NOTE(review): the original comment claims "[B, L, D] -> [B, D]", but if
    # embeddings is (B, L, D) then `[0].mean(dim=1)` averages over D of the
    # first row instead — confirm the intended axis. Behavior left unchanged.
    mean_embedding = (
        forward_output.embeddings[0].mean(dim=1)  # type: ignore
        if sampling_config.return_mean_embedding
        else None
    )

    return ForwardAndSampleOutput(
        per_residue_embedding=per_res_embed,
        mean_embedding=mean_embedding,
        **forward_and_sample_output_dir,
    )
|
||||
|
||||
|
||||
def _sample_track(
    logits: torch.Tensor,
    tokens: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
    mask_idx: int,
) -> dict[str, torch.Tensor]:
    """Sample one token track and collect its metadata (batched inputs).

    Candidate tokens are drawn at every position, then kept only where
    sampling is allowed; everywhere else the input token is preserved.
    Sampling is never allowed at the first and last position, at positions
    holding an id from ``invalid_ids`` (other than the mask id), and, when
    ``only_sample_masked_tokens`` is set, at positions that are not masked.

    Returns the dictionary built by ``_compute_track_metadata``.
    """
    candidate_tokens = sample_logits(
        logits,
        temperature=sampling_track_config.temperature,
        top_p=sampling_track_config.top_p,
    )
    log_probs = logits.log_softmax(-1)

    # Start from "sample everywhere", then carve out the first / last slots.
    may_sample = torch.ones_like(tokens, dtype=torch.bool)  # same shape as tokens
    may_sample[:, 0] = False
    may_sample[:, -1] = False

    # Positions holding an invalid id are frozen, except for the mask id,
    # which must stay sampleable.
    frozen_ids = list(set(sampling_track_config.invalid_ids) - {mask_idx})
    if len(frozen_ids) > 0:
        frozen = torch.tensor(frozen_ids, device=tokens.device)
        holds_frozen_id = (tokens[..., None] == frozen[None, :]).any(-1)
        may_sample = may_sample & ~holds_frozen_id

    # Optionally restrict sampling to masked positions only.
    if sampling_track_config.only_sample_masked_tokens:
        may_sample = may_sample & (tokens == mask_idx)

    final_tokens = torch.where(may_sample, candidate_tokens, tokens)

    return _compute_track_metadata(
        final_tokens,
        log_probs,
        may_sample,
        top_k=sampling_track_config.topk_logprobs,
    )
|
||||
|
||||
|
||||
def _sample_function_track(
    function_tokenizer: InterProQuantizedTokenizer,
    tokens: torch.Tensor,
    logits: torch.Tensor,
    sampling_track_config: SamplingTrackConfig,
) -> dict[str, torch.Tensor]:
    """Sample the function track and collect its metadata (batched inputs).

    The sampling mask is per position (it is built from ``tokens`` with the
    last axis dropped) and excludes the first and last position. With
    ``only_sample_masked_tokens`` a position is sampled only if all of its
    tokens along the last axis equal the mask id. Positions that are not
    sampled keep their input tokens and get a degenerate log-prob
    distribution that puts all mass on the existing token.

    Returns the dictionary built by ``_compute_track_metadata`` with the
    entropy summed over the last axis.
    """
    # Per-position mask: drop the trailing axis, then exclude first / last slot.
    may_sample = torch.ones_like(tokens, dtype=torch.bool)[..., 0]
    may_sample[..., 0] = False
    may_sample[..., -1] = False

    candidate_tokens, logprobs = sample_function_logits(
        logits,
        function_tokenizer,
        top_p=sampling_track_config.top_p,
        temperature=sampling_track_config.temperature,
    )

    if sampling_track_config.only_sample_masked_tokens:
        fully_masked = (tokens == function_tokenizer.mask_token_id).all(dim=-1)
        may_sample = may_sample & fully_masked

    final_tokens = torch.where(
        may_sample[..., None].expand_as(candidate_tokens), candidate_tokens, tokens
    )

    # For positions that were not sampled, replace the distribution with one
    # whose log-prob is 0 at the existing token and -inf everywhere else.
    keep_existing = torch.full_like(logprobs, -torch.inf)
    zero_src = torch.zeros_like(keep_existing)[..., [0]]
    keep_existing = torch.scatter(keep_existing, -1, tokens[..., None], zero_src)
    logprobs = torch.where(
        may_sample[..., None, None].expand_as(logprobs), logprobs, keep_existing
    )

    function_metadata = _compute_track_metadata(
        final_tokens,
        logprobs,
        may_sample,
        top_k=sampling_track_config.topk_logprobs,
    )
    # Consider the entropy of the joint distribution of all function tokens
    # at each position: sum the per-token entropies over the last axis.
    per_token_entropy = function_metadata["entropy"]
    function_metadata["entropy"] = per_token_entropy.sum(-1)
    return function_metadata
|
||||
|
||||
|
||||
def _compute_track_metadata(
|
||||
sampled_tokens: torch.Tensor,
|
||||
log_probs: torch.Tensor,
|
||||
sampling_mask: torch.Tensor,
|
||||
top_k: int,
|
||||
) -> dict:
|
||||
"""Works with inputs that have batch dimension."""
|
||||
probs = torch.exp(log_probs) # (B, L)
|
||||
entropy = torch.distributions.Categorical(logits=log_probs).entropy() # (B, L)
|
||||
|
||||
# Only compute probabilities for sampled tokens
|
||||
sampled_logprob = torch.zeros_like(sampled_tokens, dtype=log_probs.dtype) # (B, L)
|
||||
|
||||
if sampled_tokens.dim() > sampling_mask.dim():
|
||||
assert sampled_tokens.dim() == 3 # (B, L, D)
|
||||
assert sampling_mask.dim() == 2 # (B, L)
|
||||
sampling_mask = sampling_mask[..., None].expand_as(sampled_tokens)
|
||||
|
||||
sampled_tokens_valid = sampled_tokens[sampling_mask]
|
||||
sampled_log_probs_valid = log_probs[sampling_mask, sampled_tokens_valid]
|
||||
sampled_logprob[sampling_mask] = sampled_log_probs_valid
|
||||
|
||||
# Calculate extra metadata
|
||||
sampled_prob = torch.exp(sampled_logprob)
|
||||
top_prob = torch.max(probs, dim=-1).values
|
||||
topk_logprobs, topk_tokens = torch.topk(log_probs, top_k, dim=-1)
|
||||
topk_logprobs = None if top_k == 0 else topk_logprobs
|
||||
topk_tokens = None if top_k == 0 else topk_tokens
|
||||
|
||||
return {
|
||||
"entropy": entropy,
|
||||
"sampled_tokens": sampled_tokens,
|
||||
"prob": sampled_prob,
|
||||
"logprob": sampled_logprob,
|
||||
"top_prob": top_prob,
|
||||
"topk_logprob": topk_logprobs,
|
||||
"topk_tokens": topk_tokens,
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESMProteinTensor,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
@@ -13,7 +14,92 @@ from esm.tokenization import (
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants.esm3 import MAX_RESIDUE_ANNOTATIONS
|
||||
from esm.utils.constants.esm3 import (
|
||||
MAX_RESIDUE_ANNOTATIONS,
|
||||
SASA_DISCRETIZATION_BOUNDARIES,
|
||||
)
|
||||
|
||||
# Number of dimensions for each protein tensor field without the batch dimension.
|
||||
_DIMS: dict[str, int] = {
|
||||
"sequence": 1,
|
||||
"structure": 1,
|
||||
"secondary_structure": 1,
|
||||
"sasa": 1,
|
||||
"function": 2,
|
||||
"residue_annotations": 2,
|
||||
"coordinates": 3,
|
||||
}
|
||||
|
||||
|
||||
class _BatchedESMProteinTensor(ESMProteinTensor):
    """An ``ESMProteinTensor`` whose tensor fields carry a leading batch axis."""

    @staticmethod
    def from_protein_tensor(protein: ESMProteinTensor):
        """Wrap an unbatched protein tensor as a batch with one row."""

        def _with_batch_axis(x: torch.Tensor | None):
            if x is None:
                return None
            return x.unsqueeze(0)

        field_names = (
            "sequence",
            "structure",
            "secondary_structure",
            "sasa",
            "function",
            "residue_annotations",
            "coordinates",
        )
        batched_fields = {
            name: _with_batch_axis(getattr(protein, name)) for name in field_names
        }
        return _BatchedESMProteinTensor(**batched_fields)

    def __len__(self) -> int:
        """Return the size of axis 1 as detected from the fields, or 0 if none."""

        def get_len(k, v) -> int:
            # Each field must have exactly one axis more than its unbatched form.
            assert len(v.shape) == _DIMS[k] + 1
            return v.size(1)

        detected = self._detect_attribute(get_len, "length")
        if detected is None:
            return 0
        return detected

    @property
    def batch_size(self) -> int:
        """Return the size of axis 0 as detected from the fields."""

        def get_batch_size(k, v) -> int:
            assert len(v.shape) == _DIMS[k] + 1
            return v.size(0)

        detected = self._detect_attribute(get_batch_size, "batch size")
        assert detected is not None
        return detected

    def slice(
        self,
        i: int,
        sequence_len: int | None = None,
    ) -> ESMProteinTensor:
        """Extract row ``i`` as an unbatched ``ESMProteinTensor``.

        When ``sequence_len`` is given, each field of the row is also cut to
        its first ``sequence_len`` entries.
        """

        def _row(x: torch.Tensor | None):
            if x is None:
                return None
            picked = x[i]
            return picked if sequence_len is None else picked[:sequence_len]

        field_names = (
            "sequence",
            "structure",
            "secondary_structure",
            "sasa",
            "function",
            "residue_annotations",
            "coordinates",
        )
        return ESMProteinTensor(
            **{name: _row(getattr(self, name)) for name in field_names}
        )

    def set_slice(self, i: int, slice: ESMProteinTensor):
        """Update the i-th slice of this tensor data class.

        Fields that are ``None`` in ``slice`` are left untouched; writing a
        non-``None`` field into a field that is ``None`` here is an error.
        """
        for f in attr.fields(ESMProteinTensor):
            target = getattr(self, f.name)
            value = getattr(slice, f.name)

            if value is None:
                continue

            assert (
                target is not None
            ), f"Trying to set a slice on None tensor ({f.name})."
            target[i, ...] = value
|
||||
|
||||
|
||||
def get_default_sampling_config(tokenizers: TokenizerCollection) -> SamplingConfig:
|
||||
@@ -79,7 +165,8 @@ def sample_function_logits(
|
||||
temperature: float | torch.Tensor = 1.0,
|
||||
p_none_threshold: float = 0.05,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
[L, D, V] = logits.shape
|
||||
"""Works with inputs that have batch dimension."""
|
||||
[B, L, D, V] = logits.shape
|
||||
assert D == tokenizer.depth
|
||||
|
||||
if top_p < 1.0:
|
||||
@@ -87,18 +174,26 @@ def sample_function_logits(
|
||||
|
||||
temperature = torch.ones_like(logits[..., 0]) * temperature
|
||||
|
||||
log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (L, D, V)
|
||||
log_p = F.log_softmax(logits / temperature[..., None], dim=-1) # (B, L, D, V)
|
||||
|
||||
# Choose which positions have no predicted function.
|
||||
log_p_nones = log_p[..., tokenizer.vocab_to_index["<none>"]] # (L, D)
|
||||
none_index = tokenizer.vocab_to_index["<none>"]
|
||||
log_p_nones = log_p[..., none_index] # (B, L, D)
|
||||
p_none = torch.exp(log_p_nones).mean(dim=-1) # "Ensemble of <none> predictions"
|
||||
where_none = p_none > p_none_threshold # (L, )
|
||||
where_none = p_none > p_none_threshold # (B, L)
|
||||
|
||||
# Set probability of <none> to 0 for all not-none positions
|
||||
none_index = tokenizer.vocab_to_index["<none>"]
|
||||
log_p[~where_none, :, none_index] = -torch.inf
|
||||
batch_size, seq_len, depth = log_p.shape[:-1]
|
||||
expanded_where_not_none = ~where_none.unsqueeze(-1).unsqueeze(-1) # (B, L, 1, 1)
|
||||
expanded_where_not_none = expanded_where_not_none.expand(
|
||||
batch_size, seq_len, depth, 1
|
||||
) # (B, L, D, 1)
|
||||
indices = torch.arange(log_p.shape[-1], device=log_p.device) # (V,)
|
||||
mask = indices == none_index # (V,)
|
||||
mask = expanded_where_not_none & mask # (B, L, D, 1) x (V,) -> (B, L, D, V)
|
||||
log_p[mask] = -torch.inf
|
||||
|
||||
ids = torch.argmax(log_p, dim=-1) # (L, D)
|
||||
ids = torch.argmax(log_p, dim=-1) # (B, L, D)
|
||||
ids[where_none, :] = tokenizer.vocab_to_index["<none>"]
|
||||
|
||||
return ids, log_p
|
||||
@@ -110,10 +205,10 @@ def sample_residue_annotation_logits(
|
||||
# Take top residue annotations
|
||||
top_residue_annotations_idx = logits.argsort(dim=-1, descending=True)[
|
||||
..., :MAX_RESIDUE_ANNOTATIONS
|
||||
] # (L, MAX_R)
|
||||
] # (B, L, MAX_R)
|
||||
top_residue_annotations_logprobs = torch.gather(
|
||||
F.logsigmoid(logits), -1, top_residue_annotations_idx
|
||||
) # (L, MAX_R)
|
||||
) # (B, L, MAX_R)
|
||||
top_residue_annotations_probs = top_residue_annotations_logprobs.exp()
|
||||
# Keep only positive predictions
|
||||
is_negative = top_residue_annotations_probs < annotation_threshold
|
||||
@@ -124,6 +219,26 @@ def sample_residue_annotation_logits(
|
||||
return top_residue_annotations_idx, top_residue_annotations_logprobs
|
||||
|
||||
|
||||
def sample_sasa_logits(
    logits: torch.Tensor,
    tokens: torch.Tensor,
) -> torch.Tensor:
    """Turn SASA logits into one float value per position.

    The value is the probability-weighted sum of the bin midpoints, where the
    bins are delimited by ``[0] + SASA_DISCRETIZATION_BOUNDARIES`` and are
    matched against vocabulary entries ``3:-1`` of the softmaxed logits.
    Positions whose input token id is 0, 1 or 2 are set to ``-inf``
    (presumably special tokens — confirm against the SASA tokenizer), and
    positions whose most likely vocabulary index is 18 are then set to
    ``+inf``.
    """
    probabilities = logits.softmax(dim=-1)
    most_likely_idx = probabilities.argmax(dim=-1)

    # Midpoint of every bin, built on CPU and then moved next to the probabilities.
    edges = torch.tensor([0] + SASA_DISCRETIZATION_BOUNDARIES, dtype=torch.float)
    bin_midpoints = ((edges[:-1] + edges[1:]) / 2).to(probabilities.device)

    expected_sasa = (probabilities[..., 3:-1] * bin_midpoints).sum(dim=-1)

    # Order matters: the +inf override is applied after the -inf ones.
    for token_id in (0, 1, 2):
        expected_sasa[tokens == token_id] = float("-inf")
    expected_sasa[most_likely_idx == 18] = float("inf")

    return expected_sasa
|
||||
|
||||
|
||||
def top_p_logits(
|
||||
logits: torch.Tensor,
|
||||
top_p: float | torch.Tensor,
|
||||
|
||||
@@ -9,11 +9,6 @@ from typing_extensions import Self
|
||||
from esm.utils.misc import fp32_autocast_context
|
||||
|
||||
|
||||
def maybe_compile(func, x: torch.Tensor):
|
||||
# Sometimes, torch compile seems to give issues for CPU tensors...
|
||||
return torch.compile(func) if x.device.type == "cuda" else func
|
||||
|
||||
|
||||
@T.runtime_checkable
|
||||
class Rotation(T.Protocol):
|
||||
@classmethod
|
||||
@@ -154,9 +149,7 @@ class RotationMatrix(Rotation):
|
||||
x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12
|
||||
) -> RotationMatrix:
|
||||
# A low eps here is necessary for good stability!
|
||||
return RotationMatrix(
|
||||
maybe_compile(_graham_schmidt, x_axis)(x_axis, xy_plane, eps)
|
||||
)
|
||||
return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -17,6 +17,7 @@ from biotite.application.dssp import DsspApp
|
||||
from biotite.database import rcsb
|
||||
from biotite.structure.io.npz import NpzFile
|
||||
from biotite.structure.io.pdb import PDBFile
|
||||
from cloudpathlib import CloudPath
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
from torch import Tensor
|
||||
|
||||
@@ -38,7 +39,7 @@ CHAIN_ID_CONST = "A"
|
||||
|
||||
|
||||
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
|
||||
PathLike = Union[str, Path]
|
||||
PathLike = Union[str, Path, CloudPath]
|
||||
PathOrBuffer = Union[PathLike, io.StringIO]
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,9 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
PathLike = Union[str, Path]
|
||||
from cloudpathlib import CloudPath
|
||||
|
||||
PathLike = Union[str, Path, CloudPath]
|
||||
PathOrBuffer = Union[PathLike, io.StringIO]
|
||||
|
||||
|
||||
|
||||
801
examples/esmprotein.ipynb
Normal file
801
examples/esmprotein.ipynb
Normal file
@@ -0,0 +1,801 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Input tracks of `ESMProtein`\n",
|
||||
"\n",
|
||||
"ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this notebook, we will familiarize ourselves with the `ESMProtein` class, which holds multiple properties of a protein representing sequence, structure, and function. The ESM3 models use these properties from the input (prompts) and generate them as part of the output."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"An `ESMProtein` has 5 attributes that represent input (promptable) tracks:\n",
|
||||
"\n",
|
||||
"* `sequence`: amino acid sequence\n",
|
||||
"* `coordinates`: 3D coordinates of atoms in each amino acid of the protein\n",
|
||||
"* `secondary_structure`: [8-class secondary structure](https://en.wikipedia.org/wiki/Protein_secondary_structure#DSSP_classification) (SS8)\n",
|
||||
"* `sasa`: [solvent-accessible surface area](https://en.wikipedia.org/wiki/Accessible_surface_area) (SASA)\n",
|
||||
"* `function_annotations`: function annotations derived from [InterPro](https://www.ebi.ac.uk/interpro/)\n",
|
||||
"\n",
|
||||
"You can prompt an ESM3 model by setting any subset of these tracks to be partially unmasked when calling the model with an `ESMProtein` instance."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One way to create an `ESMProtein` object is from a pdb id and chain id from [RCSB](https://www.rcsb.org). Below, we first create a `ProteinChain` with the pdb id and chain id and then create an `ESMProtein` from it. This will populate the `sequence` and `coordinates` properties."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Install esm\n",
|
||||
"! pip install esm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from biotite.database import rcsb\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain\n",
|
||||
"from esm.sdk.api import ESMProtein\n",
|
||||
"from esm.utils.types import FunctionAnnotation\n",
|
||||
"\n",
|
||||
"pdb_id = \"1cm4\"\n",
|
||||
"chain_id = \"A\"\n",
|
||||
"\n",
|
||||
"# Create a protein using a pdb format file from RCSB\n",
|
||||
"# Note: instead of the next two lines, we could use\n",
|
||||
"# protein_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n",
|
||||
"# but in future implementations, this function may use the mmcif file\n",
|
||||
"# which would throw off some indices later on in this notebook\n",
|
||||
"str_io = rcsb.fetch(pdb_id, \"pdb\")\n",
|
||||
"protein_chain = ProteinChain.from_pdb(str_io, chain_id=chain_id, id=pdb_id)\n",
|
||||
"protein = ESMProtein.from_protein_chain(protein_chain)\n",
|
||||
"\n",
|
||||
"## We can also load from a local pdb file by passing its path\n",
|
||||
"# protein_chain = ProteinChain.from_pdb('xxxx.pdb', chain_id=chain_id, id=pdb_id)\n",
|
||||
"# The chain_id and id arguments are optional and will be inferred if None"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### `sequence`\n",
|
||||
"The `sequence` track contains the protein's amino acid sequence, with each amino acid written as its 1-letter code:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(protein.sequence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### `coordinates`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"`coordinates` contains the 3D coordinates of atoms in the protein. It contains a tensor of shape `(n_residues, 37, 3)`, where \n",
|
||||
"\n",
|
||||
"* `n_residues` is the number of amino acids in the protein.\n",
|
||||
"* `37` is the maximum possible number of atoms in an amino acid, represented in the atom37 representation. If certain atoms are not present in the structure, they will show up as `nan`.\n",
|
||||
"* `3` is for 3D (x,y,z) coordinates. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(protein.coordinates.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(protein.coordinates)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The cell below defines two functions that visualize the `coordinates` attribute (there is no need to read through their implementation):\n",
|
||||
"* `visualize_3D_coordinates()` visualizes directly from the coordinates tensor by creating a pdb file with all alanines\n",
|
||||
"* `visualize_3D_protein()` visualizes from the `ESMProtein` instance, which has the correct amino acids"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"\n",
|
||||
"Functions for visualizing 3D structure\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"! pip install py3Dmol\n",
|
||||
"\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"def visualize_pdb(pdb_string):\n",
|
||||
" view = py3Dmol.view(width=400, height=400)\n",
|
||||
" view.addModel(pdb_string, \"pdb\")\n",
|
||||
" view.setStyle({'cartoon': {'color': 'spectrum'}})\n",
|
||||
" view.zoomTo()\n",
|
||||
" view.render()\n",
|
||||
" view.center()\n",
|
||||
" return view\n",
|
||||
"\n",
|
||||
"def visualize_3D_coordinates(coordinates):\n",
|
||||
" \"\"\"\n",
|
||||
" This uses all Alanines\n",
|
||||
" \"\"\"\n",
|
||||
" protein_with_same_coords = ESMProtein(coordinates=coordinates)\n",
|
||||
" # pdb with all alanines\n",
|
||||
" pdb_string = protein_with_same_coords.to_pdb_string()\n",
|
||||
" return visualize_pdb(pdb_string)\n",
|
||||
"\n",
|
||||
"def visualize_3D_protein(protein):\n",
|
||||
" pdb_string = protein.to_pdb_string()\n",
|
||||
" return visualize_pdb(pdb_string)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# visualize from just the coordinates\n",
|
||||
"visualize_3D_coordinates(protein.coordinates)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# visualize from sequence and coordinates\n",
|
||||
"visualize_3D_protein(protein)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### `secondary_structure`\n",
|
||||
"\n",
|
||||
"The `secondary_structure` property contains a representation of the secondary structure. At a high level of categorization, we can classify each amino acid as belonging to one of three classes: alpha helices, beta sheets, and coils, which we could see in the previous 3D visualization.\n",
|
||||
"\n",
|
||||
"`ESMProtein` uses an [8-class secondary structure](https://en.wikipedia.org/wiki/Protein_secondary_structure#DSSP_classification) that can be computed with [dssp](https://swift.cmbi.umcn.nl/gv/dssp/) given 3D atom coordinates. Since installing dssp is a separate process from installing the `esm` package, in this notebook, we show how to compute the coarser 3-class classification using biotite's [annotate_sse](https://www.biotite-python.org/apidoc/biotite.structure.annotate_sse.html). We can set the `secondary_structure` property with this 3-class classification."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from biotite.structure import annotate_sse\n",
|
||||
"\n",
|
||||
"def get_approximate_ss(protein_chain: ProteinChain):\n",
|
||||
" # get biotite's ss3 representation\n",
|
||||
" ss3_arr = annotate_sse(protein_chain.atom_array)\n",
|
||||
" biotite_ss3_str = ''.join(ss3_arr)\n",
|
||||
"\n",
|
||||
" # translate into ESM3's representation\n",
|
||||
" translation_table = str.maketrans({\n",
|
||||
" 'a': 'H', # alpha helix\n",
|
||||
" 'b': 'E', # beta sheet\n",
|
||||
" 'c': 'C', # coil\n",
|
||||
" })\n",
|
||||
" esm_ss3 = biotite_ss3_str.translate(translation_table)\n",
|
||||
" return esm_ss3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"protein.secondary_structure = get_approximate_ss(protein_chain)\n",
|
||||
"print(protein.secondary_structure)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The next cell defines helper classes and functions that visualize the secondary structure; there is no need to read through them!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ! pip install matplotlib\n",
|
||||
"\n",
|
||||
"# Slightly modified version of secondary structure plotting code from\n",
|
||||
"# https://www.biotite-python.org/examples/gallery/structure/transketolase_sse.html\n",
|
||||
"# Code source: Patrick Kunzmann\n",
|
||||
"# License: BSD 3 clause\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from matplotlib.patches import Rectangle\n",
|
||||
"import biotite\n",
|
||||
"import biotite.sequence as seq\n",
|
||||
"import biotite.sequence.graphics as graphics\n",
|
||||
"\n",
|
||||
"# Create 'FeaturePlotter' subclasses\n",
|
||||
"# for drawing the secondary structure features\n",
|
||||
"class HelixPlotter(graphics.FeaturePlotter):\n",
|
||||
"\n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" # Check whether this class is applicable for drawing a feature\n",
|
||||
" def matches(self, feature):\n",
|
||||
" if feature.key == \"SecStr\":\n",
|
||||
" if \"sec_str_type\" in feature.qual:\n",
|
||||
" if feature.qual[\"sec_str_type\"] == \"helix\":\n",
|
||||
" return True\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" # The drawing function itself\n",
|
||||
" def draw(self, axes, feature, bbox, loc, style_param):\n",
|
||||
" # Approx. 1 turn per 3.6 residues to resemble natural helix\n",
|
||||
" n_turns = np.ceil((loc.last - loc.first + 1) / 3.6)\n",
|
||||
" x_val = np.linspace(0, n_turns * 2*np.pi, 100)\n",
|
||||
" # Curve ranges from 0.3 to 0.7\n",
|
||||
" y_val = (-0.4*np.sin(x_val) + 1) / 2\n",
|
||||
"\n",
|
||||
" # Transform values for correct location in feature map\n",
|
||||
" x_val *= bbox.width / (n_turns * 2*np.pi)\n",
|
||||
" x_val += bbox.x0\n",
|
||||
" y_val *= bbox.height\n",
|
||||
" y_val += bbox.y0\n",
|
||||
"\n",
|
||||
" # Draw white background to overlay the guiding line\n",
|
||||
" background = Rectangle(\n",
|
||||
" bbox.p0, bbox.width, bbox.height, color=\"white\", linewidth=0\n",
|
||||
" )\n",
|
||||
" axes.add_patch(background)\n",
|
||||
" axes.plot(\n",
|
||||
" x_val, y_val, linewidth=2, color=biotite.colors[\"dimgreen\"]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class SheetPlotter(graphics.FeaturePlotter):\n",
|
||||
"\n",
|
||||
" def __init__(self, head_width=0.8, tail_width=0.5):\n",
|
||||
" self._head_width = head_width\n",
|
||||
" self._tail_width = tail_width\n",
|
||||
"\n",
|
||||
" def matches(self, feature):\n",
|
||||
" if feature.key == \"SecStr\":\n",
|
||||
" if \"sec_str_type\" in feature.qual:\n",
|
||||
" if feature.qual[\"sec_str_type\"] == \"sheet\":\n",
|
||||
" return True\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" def draw(self, axes, feature, bbox, loc, style_param):\n",
|
||||
" x = bbox.x0\n",
|
||||
" y = bbox.y0 + bbox.height/2\n",
|
||||
" dx = bbox.width\n",
|
||||
" dy = 0\n",
|
||||
"\n",
|
||||
" if loc.defect & seq.Location.Defect.MISS_RIGHT:\n",
|
||||
" # If the feature extends into the previous or next line\n",
|
||||
" # do not draw an arrow head\n",
|
||||
" draw_head = False\n",
|
||||
" else:\n",
|
||||
" draw_head = True\n",
|
||||
"\n",
|
||||
" axes.add_patch(biotite.AdaptiveFancyArrow(\n",
|
||||
" x, y, dx, dy,\n",
|
||||
" self._tail_width*bbox.height, self._head_width*bbox.height,\n",
|
||||
" # Create head with 90 degrees tip\n",
|
||||
" # -> head width/length ratio = 1/2\n",
|
||||
" head_ratio=0.5, draw_head=draw_head,\n",
|
||||
" color=biotite.colors[\"orange\"], linewidth=0\n",
|
||||
" ))\n",
|
||||
"\n",
|
||||
"# Converter for the DSSP secondary structure elements\n",
|
||||
"# to the classical ones\n",
|
||||
"dssp_to_abc = {\"I\" : \"c\",\n",
|
||||
" \"S\" : \"c\",\n",
|
||||
" \"H\" : \"a\",\n",
|
||||
" \"E\" : \"b\",\n",
|
||||
" \"G\" : \"c\",\n",
|
||||
" \"B\" : \"b\",\n",
|
||||
" \"T\" : \"c\",\n",
|
||||
" \"C\" : \"c\"}\n",
|
||||
"\n",
|
||||
"def visualize_secondary_structure(sse, first_id):\n",
|
||||
" \"\"\"\n",
|
||||
" Helper function to convert secondary structure array to annotation\n",
|
||||
" and visualize it.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def _add_sec_str(annotation, first, last, str_type):\n",
|
||||
" if str_type == \"a\":\n",
|
||||
" str_type = \"helix\"\n",
|
||||
" elif str_type == \"b\":\n",
|
||||
" str_type = \"sheet\"\n",
|
||||
" else:\n",
|
||||
" # coil\n",
|
||||
" return\n",
|
||||
" feature = seq.Feature(\n",
|
||||
" \"SecStr\", [seq.Location(first, last)], {\"sec_str_type\" : str_type}\n",
|
||||
" )\n",
|
||||
" annotation.add_feature(feature)\n",
|
||||
"\n",
|
||||
" # Find the intervals for each secondary structure element\n",
|
||||
" # and add to annotation\n",
|
||||
" annotation = seq.Annotation()\n",
|
||||
" curr_sse = None\n",
|
||||
" curr_start = None\n",
|
||||
" for i in range(len(sse)):\n",
|
||||
" if curr_start is None:\n",
|
||||
" curr_start = i\n",
|
||||
" curr_sse = sse[i]\n",
|
||||
" else:\n",
|
||||
" if sse[i] != sse[i-1]:\n",
|
||||
" _add_sec_str(\n",
|
||||
" annotation, curr_start+first_id, i-1+first_id, curr_sse\n",
|
||||
" )\n",
|
||||
" curr_start = i\n",
|
||||
" curr_sse = sse[i]\n",
|
||||
" # Add last secondary structure element to annotation\n",
|
||||
" _add_sec_str(annotation, curr_start+first_id, i+first_id, curr_sse)\n",
|
||||
"\n",
|
||||
" fig = plt.figure(figsize=(30.0, 3.0))\n",
|
||||
" ax = fig.add_subplot(111)\n",
|
||||
" graphics.plot_feature_map(\n",
|
||||
" ax, annotation, symbols_per_line=150,\n",
|
||||
" loc_range=(first_id, first_id+len(sse)),\n",
|
||||
" feature_plotters=[HelixPlotter(), SheetPlotter()]\n",
|
||||
" )\n",
|
||||
" fig.tight_layout()\n",
|
||||
" return fig, ax\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def plot_ss8(ss8_string):\n",
|
||||
" ss3 = np.array([dssp_to_abc[e] for e in ss8_string], dtype=\"U1\")\n",
|
||||
" _, ax = visualize_secondary_structure(ss3, 1)\n",
|
||||
" ax.set_xticks([])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Using these functions, we can visualize the secondary structure that we obtained. The alpha helices are represented in green, the beta sheets are represented in orange, and coils are represented by gray lines. \n",
|
||||
"\n",
|
||||
"Note: because the secondary structure assignment algorithm used here is not the same as the one used by the 3D visualization, this plot differs slightly from the cartoon representation in the 3D view."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_ss8(protein.secondary_structure)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### `function_annotations`\n",
|
||||
"\n",
|
||||
"An `ESMProtein` also contains function annotations derived from [InterPro](https://www.ebi.ac.uk/interpro/). Annotations directly from InterPro contain information about the following [entry types](https://interpro-documentation.readthedocs.io/en/latest/faq.html#what-are-entry-types):\n",
|
||||
"* Family\n",
|
||||
"* Domain\n",
|
||||
"* Homologous superfamily\n",
|
||||
"* Repeat\n",
|
||||
"* Site (conserved site, active site, binding site, post-translational modification site)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from esm.utils.types import FunctionAnnotation\n",
|
||||
"\n",
|
||||
"interpro_function_annotations = [\n",
|
||||
" FunctionAnnotation(label=\"IPR050145\", start=1, end=142), # 1 indexed, inclusive;\n",
|
||||
" FunctionAnnotation(label=\"IPR002048\", start=4, end=75),\n",
|
||||
" FunctionAnnotation(label=\"IPR002048\", start=77, end=144),\n",
|
||||
" FunctionAnnotation(label=\"IPR011992\", start=1, end=143),\n",
|
||||
" FunctionAnnotation(label=\"IPR018247\", start=17, end=29),\n",
|
||||
" FunctionAnnotation(label=\"IPR018247\", start=53, end=65),\n",
|
||||
" FunctionAnnotation(label=\"IPR018247\", start=90, end=102),\n",
|
||||
" FunctionAnnotation(label=\"IPR018247\", start=126, end=138),\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can visualize these InterPro annotations with the following function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"\n",
|
||||
"Functions for visualizing InterPro function annotations\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"! pip install dna-features-viewer\n",
|
||||
"\n",
|
||||
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
|
||||
"from esm.utils.function.interpro import InterProEntryType, InterPro\n",
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"def visualize_function_annotations(\n",
|
||||
" annotations: list[FunctionAnnotation],\n",
|
||||
" sequence_length: int,\n",
|
||||
" ax: plt.Axes,\n",
|
||||
" interpro_ = InterPro(),\n",
|
||||
"):\n",
|
||||
" cmap = colormaps[\"tab10\"]\n",
|
||||
" colors = [cmap(i) for i in range(len(InterProEntryType))]\n",
|
||||
" type_colors = dict(zip(InterProEntryType, colors))\n",
|
||||
"\n",
|
||||
" features = []\n",
|
||||
" for annotation in annotations:\n",
|
||||
" if annotation.label in interpro_.entries:\n",
|
||||
" entry = interpro_.entries[annotation.label]\n",
|
||||
" label = entry.name\n",
|
||||
" entry_type = entry.type\n",
|
||||
" else:\n",
|
||||
" label = annotation.label\n",
|
||||
" entry_type = InterProEntryType.UNKNOWN\n",
|
||||
"\n",
|
||||
" feature = GraphicFeature(\n",
|
||||
" start=annotation.start - 1, # one index -> zero index\n",
|
||||
" end=annotation.end,\n",
|
||||
" label=label,\n",
|
||||
" color=type_colors[entry_type],\n",
|
||||
" strand=None,\n",
|
||||
" )\n",
|
||||
" features.append(feature)\n",
|
||||
"\n",
|
||||
" record = GraphicRecord(\n",
|
||||
" sequence=None,\n",
|
||||
" sequence_length=sequence_length,\n",
|
||||
" features=features,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" record.plot(figure_width=12, plot_sequence=False, ax=ax)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We plot the InterPro annotations below, with colors indicating the entry type of the InterPro annotation "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, ax = plt.subplots(figsize=(20.0, 4.0))\n",
|
||||
"visualize_function_annotations(\n",
|
||||
" interpro_function_annotations,\n",
|
||||
" len(protein),\n",
|
||||
" ax,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When using our `ESM3` model, we recommend you use keyword annotations, which are keywords in the description of the InterPro entry and associated Gene Ontology terms from [InterPro2GO](https://www.ebi.ac.uk/GOA/InterPro2GO). For instance, for the InterPro entry [IPR011992](https://www.ebi.ac.uk/interpro/entry/InterPro/IPR011992/), the keywords are \"domain pair\", \"hand domain\", \"ef hand\", \"pair\", and \"ef\". For more details regarding how the keywords were computed, please refer to our preprint.\n",
|
||||
"\n",
|
||||
"Practically, we can derive keyword annotations from the InterPro annotations with the function below. Each InterPro annotation corresponds to multiple keyword annotations covering the same range."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from esm.tokenization import InterProQuantizedTokenizer\n",
|
||||
"\n",
|
||||
"def get_keywords_from_interpro(\n",
|
||||
" interpro_annotations,\n",
|
||||
" interpro2keywords=InterProQuantizedTokenizer().interpro2keywords,\n",
|
||||
"):\n",
|
||||
" keyword_annotations_list = []\n",
|
||||
" for interpro_annotation in interpro_annotations:\n",
|
||||
" keywords = interpro2keywords.get(interpro_annotation.label, [])\n",
|
||||
" keyword_annotations_list.extend([\n",
|
||||
" FunctionAnnotation(\n",
|
||||
" label=keyword,\n",
|
||||
" start=interpro_annotation.start,\n",
|
||||
" end=interpro_annotation.end,\n",
|
||||
" )\n",
|
||||
" for keyword in keywords\n",
|
||||
" ])\n",
|
||||
" return keyword_annotations_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"protein.function_annotations = get_keywords_from_interpro(interpro_function_annotations)\n",
|
||||
"protein.function_annotations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also visualize the keyword annotations. They all share the same color because keyword labels do not correspond to a known InterPro entry type."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig, ax = plt.subplots(figsize=(20.0, 8.0))\n",
|
||||
"visualize_function_annotations(\n",
|
||||
" protein.function_annotations,\n",
|
||||
" len(protein),\n",
|
||||
" ax,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### `sasa`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The final input track of `ESMProtein` is the solvent-accessible surface area, or [SASA](https://en.wikipedia.org/wiki/Accessible_surface_area). For each amino acid, this track indicates how much of it is accessible to the solvent. We can compute this by `ProteinChain`'s `sasa` function, which uses biotite's [`sasa`](https://www.biotite-python.org/apidoc/biotite.structure.sasa.html) function under the hood."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"protein.sasa = protein_chain.sasa()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One way to visualize this track is to plot its values as they vary along the amino acid sequence."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.plot(protein.sasa)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can also map these SASA values onto the 3D visualization of the structure, leveraging the fact that we have this protein's 3D coordinates.\n",
|
||||
"\n",
|
||||
"First we define which colors map to which values:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"cmap = colormaps['cividis']\n",
|
||||
"clip_sasa_lower = 10\n",
|
||||
"clip_sasa_upper = 90\n",
|
||||
"\n",
|
||||
"def plot_heatmap_legend(\n",
|
||||
" cmap,\n",
|
||||
" clip_sasa_lower,\n",
|
||||
" clip_sasa_upper,\n",
|
||||
"):\n",
|
||||
" gradient = np.linspace(0, 1, 256)\n",
|
||||
" gradient = np.vstack((gradient, gradient))\n",
|
||||
" _, ax = plt.subplots(figsize=(5, 0.3), dpi=350)\n",
|
||||
" ax.imshow(gradient, aspect='auto', cmap=cmap)\n",
|
||||
" ax.text(0.1, -0.3, f'{clip_sasa_lower} or lower', va='center', ha='right', fontsize=7, transform=ax.transAxes)\n",
|
||||
" ax.text(0.5, -0.3, f'{(clip_sasa_lower + clip_sasa_upper) // 2}', va='center', ha='right', fontsize=7, transform=ax.transAxes)\n",
|
||||
" ax.text(0.9, -0.3, f'{clip_sasa_upper} or higher', va='center', ha='left', fontsize=7, transform=ax.transAxes)\n",
|
||||
" ax.set_xticklabels([])\n",
|
||||
" ax.set_yticklabels([])\n",
|
||||
" ax.set_xticks([])\n",
|
||||
" ax.set_yticks([])\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"\n",
|
||||
"Functions for visualizing SASA as colors on the 3D structure\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"def get_color_strings(\n",
|
||||
" sasa,\n",
|
||||
" clip_sasa_lower,\n",
|
||||
" clip_sasa_upper,\n",
|
||||
" cmap,\n",
|
||||
"):\n",
|
||||
" transformed_sasa = np.clip(sasa, clip_sasa_lower, clip_sasa_upper)\n",
|
||||
" transformed_sasa = (transformed_sasa - clip_sasa_lower) / (clip_sasa_upper - clip_sasa_lower)\n",
|
||||
" rgbas = (cmap(transformed_sasa) * 255).astype(int)\n",
|
||||
"\n",
|
||||
" return [\n",
|
||||
" f'rgb({rgba[0]},{rgba[1]},{rgba[2]})'\n",
|
||||
" for rgba in rgbas\n",
|
||||
" ] \n",
|
||||
"\n",
|
||||
"def visualize_sasa_3D_protein(\n",
|
||||
" protein,\n",
|
||||
" clip_sasa_lower=clip_sasa_lower,\n",
|
||||
" clip_sasa_upper=clip_sasa_upper,\n",
|
||||
" cmap=cmap,\n",
|
||||
"):\n",
|
||||
" pdb_string = protein.to_pdb_string()\n",
|
||||
" plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper)\n",
|
||||
" view = py3Dmol.view(width=400, height=400)\n",
|
||||
" view.addModel(pdb_string, \"pdb\")\n",
|
||||
"\n",
|
||||
" for res_pos, res_color in enumerate(get_color_strings(protein.sasa, clip_sasa_lower, clip_sasa_upper, cmap)):\n",
|
||||
" view.setStyle({'chain': 'A', 'resi': res_pos+1}, {'cartoon': {'color': res_color}})\n",
|
||||
" view.zoomTo()\n",
|
||||
" view.render()\n",
|
||||
" view.center()\n",
|
||||
"\n",
|
||||
" return view"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We visualize SASA on the 3D structure below. Note that the amino acids that are on the inside have lower SASA values, and the amino acids at the surface have higher SASA values."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"visualize_sasa_3D_protein(protein)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We have now covered all the tracks of `ESMProtein`. \n",
|
||||
"\n",
|
||||
"We can initialize an `ESMProtein` by providing any of these tracks. For instance, to initialize a protein with the same coordinates as our `protein`, we would do:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"same_structure_protein = ESMProtein(coordinates=protein.coordinates)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"and similarly for any other track."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We hope this helps you get started with our ESM3 models!"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -2,6 +2,8 @@ from esm.models.esm3 import ESM3
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
GenerationConfig,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
@@ -33,6 +35,20 @@ def main(client: ESM3InferenceClient):
|
||||
single_step_protein.protein_tensor.sequence = protein.sequence
|
||||
single_step_protein = client.decode(single_step_protein.protein_tensor)
|
||||
|
||||
# Generate with partial sequence.
|
||||
prompt = (
|
||||
"___________________________________________________DQATSLRILNNGHAFNVEFDDSQDKAVLK"
|
||||
"GGPLDGTYRLIQFHFHWGSLDGQGSEHTVDKKKYAAELHLVHWNTKYGDFGKAVQQPDGLAVLGIFLKVGSAKPGLQKVVDVLDSIK"
|
||||
"TKGKSADFTNFDPRGLLPESLDYWTYPGSLTTPP___________________________________________________________"
|
||||
)
|
||||
protein = ESMProtein(sequence=prompt)
|
||||
protein = client.generate(
|
||||
protein,
|
||||
GenerationConfig(track="sequence", num_steps=8, temperature=0.7),
|
||||
)
|
||||
assert isinstance(protein, ESMProtein)
|
||||
print(protein.sequence)
|
||||
|
||||
# Folding
|
||||
protein = get_sample_protein()
|
||||
sequence_length = len(protein.sequence) # type: ignore
|
||||
@@ -44,9 +60,10 @@ def main(client: ESM3InferenceClient):
|
||||
protein,
|
||||
GenerationConfig(track="structure", schedule="cosine", num_steps=num_steps),
|
||||
)
|
||||
assert isinstance(folded_protein, ESMProtein)
|
||||
folded_protein.to_pdb("./sample_folded.pdb")
|
||||
|
||||
# Inverse Folding
|
||||
# Inverse folding
|
||||
protein = get_sample_protein()
|
||||
protein.sequence = None
|
||||
protein.sasa = None
|
||||
@@ -55,8 +72,19 @@ def main(client: ESM3InferenceClient):
|
||||
protein,
|
||||
GenerationConfig(track="sequence", schedule="cosine", num_steps=num_steps),
|
||||
)
|
||||
assert isinstance(inv_folded_protein, ESMProtein)
|
||||
inv_folded_protein.to_pdb("./sample_inv_folded.pdb")
|
||||
|
||||
# Function prediction
|
||||
protein = get_sample_protein()
|
||||
protein.function_annotations = None
|
||||
protein_with_function = client.generate(
|
||||
protein,
|
||||
GenerationConfig(track="function", schedule="cosine", num_steps=num_steps),
|
||||
)
|
||||
assert isinstance(protein_with_function, ESMProtein)
|
||||
print(protein_with_function.function_annotations)
|
||||
|
||||
# Chain of Thought (Function -> Secondary Structure -> Structure -> Sequence)
|
||||
cot_protein = get_sample_protein()
|
||||
cot_protein.sequence = "_" * len(cot_protein.sequence) # type: ignore
|
||||
@@ -68,9 +96,51 @@ def main(client: ESM3InferenceClient):
|
||||
cot_protein_tensor,
|
||||
GenerationConfig(track=cot_track, schedule="cosine", num_steps=10),
|
||||
)
|
||||
assert isinstance(cot_protein_tensor, ESMProteinTensor)
|
||||
cot_protein = client.decode(cot_protein_tensor)
|
||||
|
||||
assert isinstance(cot_protein, ESMProtein)
|
||||
cot_protein.to_pdb("./sample_cot.pdb")
|
||||
|
||||
# Batch examples.
|
||||
|
||||
# Batch generation.
|
||||
prompts = [ESMProtein(sequence=("_" * (10 + 2 * i))) for i in range(5)]
|
||||
configs = [
|
||||
GenerationConfig(track="sequence", schedule="cosine", num_steps=(i + 1))
|
||||
for i in range(5)
|
||||
]
|
||||
proteins = client.batch_generate(prompts, configs)
|
||||
|
||||
# Batch folding.
|
||||
# Take the list of proteins batch-generated in the previous step.
|
||||
configs = [
|
||||
GenerationConfig(track="structure", schedule="cosine", num_steps=(i + 1))
|
||||
for i in range(5)
|
||||
]
|
||||
# Generate again for the structure track.
|
||||
proteins = client.batch_generate(proteins, configs)
|
||||
# Now write sequence and structure to PDB files.
|
||||
for i, p in enumerate(proteins):
|
||||
assert isinstance(p, ESMProtein)
|
||||
p.to_pdb(f"./batch_gen_{i}.pdb")
|
||||
|
||||
# Batch generation returns ESMProteinError for specific prompts that have issues.
|
||||
prompts = [ESMProtein(sequence=("_" * (10 + 2 * i))) for i in range(5)]
|
||||
# Mock error situation. The third prompt has no masks to be sampled.
|
||||
prompts[2].sequence = "ANTVPYQ"
|
||||
configs = [
|
||||
GenerationConfig(track="sequence", schedule="cosine", num_steps=(i + 1))
|
||||
for i in range(5)
|
||||
]
|
||||
proteins = client.batch_generate(prompts, configs)
|
||||
# Should still get results, but the third result is an ESMProteinError.
|
||||
for i, p in enumerate(proteins):
|
||||
if i == 2:
|
||||
assert isinstance(p, ESMProteinError)
|
||||
else:
|
||||
assert isinstance(p, ESMProtein)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(ESM3.from_pretrained("esm3_sm_open_v1"))
|
||||
|
||||
@@ -34,6 +34,7 @@ dependencies = [
|
||||
"brotli",
|
||||
"attrs",
|
||||
"pandas",
|
||||
"cloudpathlib",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
|
||||
Reference in New Issue
Block a user