sync 3.2.2.post1 (#270)

This commit is contained in:
Neil Thomas
2025-09-17 14:24:07 -07:00
committed by GitHub
parent 7454c3d77b
commit f4f97929b2
58 changed files with 2513 additions and 141 deletions

View File

@@ -38,7 +38,6 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -2,7 +2,6 @@ import random
import torch
import torch.nn.functional as F
from esm.pretrained import (
ESM3_function_decoder_v0,
ESM3_sm_open_v0,
@@ -13,7 +12,9 @@ from esm.tokenization import get_esm3_model_tokenizers
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

View File

@@ -2,7 +2,6 @@ import os
from typing import cast
import numpy as np
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,

View File

@@ -72,7 +72,6 @@
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.utils.types import FunctionAnnotation\n",
@@ -497,9 +496,8 @@
"# Functions for visualizing InterPro function annotations\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from matplotlib import colormaps\n",
"\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"from matplotlib import colormaps\n",
"\n",
"\n",
"def visualize_function_annotations(\n",

View File

@@ -49,18 +49,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{

View File

@@ -64,7 +64,6 @@
"import matplotlib.pyplot as pl\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
@@ -80,18 +79,18 @@
"\n",
"The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
"\n",
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "zNrU9Q2SYonX"
},
"outputs": [],
"source": [
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{

View File

@@ -36,7 +36,6 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"
@@ -53,7 +52,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
@@ -64,7 +63,7 @@
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},

View File

@@ -49,7 +49,6 @@
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]
@@ -120,7 +119,7 @@
"\n",
"from esm.sdk import client\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(\n",
" model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
")"

View File

@@ -1 +1,2 @@
__version__ = "3.2.2"
__version__ = "3.2.2.post1"

View File

@@ -5,7 +5,10 @@ import torch
import torch.nn.functional as F
from torch import nn
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
from esm.layers.rotary import (
RotaryEmbedding,
TritonRotaryEmbedding,
)
try:
from flash_attn import flash_attn_varlen_qkvpacked_func

View File

@@ -2,8 +2,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
from esm.layers.attention import (
FlashMultiHeadAttention,
MultiHeadAttention,
)
from esm.layers.geom_attention import (
GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D

View File

@@ -2,7 +2,10 @@ import torch
import torch.nn as nn
from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
)
class Dim6RotStructureHead(nn.Module):

View File

@@ -13,7 +13,10 @@ from attr import dataclass
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
@@ -29,7 +32,10 @@ from esm.sdk.api import (
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
@@ -44,7 +50,9 @@ from esm.utils.sampling import (
get_default_sampling_config,
validate_sampling_config,
)
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
)
@dataclass

View File

@@ -12,7 +12,9 @@ from cloudpathlib import AnyPath
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_annotations, merge_ranges
from esm.utils.types import FunctionAnnotation

View File

@@ -7,7 +7,10 @@ from esm.layers.structure_proj import Dim6RotStructureHead
from esm.layers.transformer_stack import TransformerStack
from esm.utils.constants import esm3 as C
from esm.utils.misc import knn_graph
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
from esm.utils.structure.affine3d import (
Affine3D,
build_affine3d_from_coordinates,
)
from esm.utils.structure.predicted_aligned_error import (
compute_predicted_aligned_error,
compute_tm,

View File

@@ -6,8 +6,14 @@ import torch.nn as nn
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,

View File

@@ -2,19 +2,27 @@ from __future__ import annotations
from abc import ABC
from copy import deepcopy
from typing import Sequence
from typing import List, Sequence
import attr
import torch
from attr import asdict, define
import esm.utils.constants.api as C
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
from esm.tokenization import (
TokenizerCollectionProtocol,
get_esm3_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
from esm.utils.misc import (
get_chainbreak_boundaries_from_sequence,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
from esm.utils.structure.protein_complex import (
SINGLE_LETTER_CHAIN_IDS,
ProteinComplex,
)
from esm.utils.types import FunctionAnnotation, PathOrBuffer
@@ -35,6 +43,7 @@ class ESMProtein(ProteinType):
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None
# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
@@ -148,12 +157,35 @@ class ESMProtein(ProteinType):
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
else:
gt_chains = None
# Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
# This handles the case where the server doesn't include chain breaks in pLDDT
# We should fix this in the server side.
if self.plddt is not None and len(self.plddt) != len(self.sequence):
# Only expand if there's a mismatch (likely due to chain breaks)
if "|" in self.sequence:
# Create expanded pLDDT with NaN at chain break positions
expanded_plddt = torch.full((len(self.sequence),), float("nan"))
plddt_idx = 0
for i, aa in enumerate(self.sequence):
if aa != "|":
if plddt_idx < len(self.plddt):
expanded_plddt[i] = self.plddt[plddt_idx]
plddt_idx += 1
plddt = expanded_plddt
else:
# Mismatch but no chain breaks - shouldn't happen but preserve original
plddt = self.plddt
else:
plddt = self.plddt
pred_chains = []
for i, (start, end) in enumerate(chain_boundaries):
if i >= len(SINGLE_LETTER_CHAIN_IDS):
raise ValueError(
f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
)
pred_chain = ProteinChain.from_atom37(
atom37_positions=coords[start:end],
sequence=self.sequence[start:end],
@@ -161,7 +193,7 @@ class ESMProtein(ProteinType):
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
confidence=self.plddt[start:end] if self.plddt is not None else None,
confidence=plddt[start:end] if plddt is not None else None,
)
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
@@ -298,19 +330,14 @@ class GenerationConfig:
self.temperature_annealing = True
@define
class MSA:
# Paired MSA sequences.
# One would typically compute these using, for example, ColabFold.
sequences: list[str]
@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
temperature: float = 1.0
## Low Level Endpoint Types
@define
class SamplingTrackConfig:
@@ -375,6 +402,9 @@ class LogitsConfig:
ith_hidden_layer: int = -1
@define
class LogitsOutput:
logits: ForwardTrackData | None = None

View File

@@ -1,9 +1,13 @@
import asyncio
import time
from abc import ABC, abstractmethod
from typing import Any
from urllib.parse import urljoin
import httpx
from esm.sdk.api import ESMProteinError
from esm.sdk.retry import retry_decorator
from esm.utils.decoding import assemble_message
@@ -80,6 +84,10 @@ class _BaseForgeInferenceClient:
headers = {**self.headers, **headers}
if return_bytes:
headers["return-bytes"] = "true"
# __INTERNAL_BEGIN___
if disable_cache:
headers["X-Disable-Cache"] = "true"
# __INTERNAL_END___
return request, headers
def prepare_data(self, response, endpoint: str) -> dict[str, Any]:
@@ -112,7 +120,11 @@ class _BaseForgeInferenceClient:
):
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
request,
potential_sequence_of_concern,
return_bytes,
disable_cache,
headers,
)
response = await self.async_client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
@@ -142,7 +154,10 @@ class _BaseForgeInferenceClient:
):
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
request,
potential_sequence_of_concern,
return_bytes,
headers,
)
response = self.client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
@@ -160,3 +175,5 @@ class _BaseForgeInferenceClient:
error_code=500,
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
)

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
import asyncio
import base64
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Sequence
from typing import Any, Literal, Sequence, cast
import torch
from esm.sdk.api import (
MSA,
ESM3InferenceClient,
ESMCInferenceClient,
ESMProtein,
@@ -19,14 +20,30 @@ from esm.sdk.api import (
InverseFoldingConfig,
LogitsConfig,
LogitsOutput,
ProteinChain,
ProteinType,
SamplingConfig,
SamplingTrackConfig,
)
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
from esm.sdk.base_forge_client import (
_BaseForgeInferenceClient,
)
from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
from esm.utils.misc import (
deserialize_tensors,
maybe_list,
maybe_tensor,
)
from esm.utils.msa import MSA
from esm.utils.structure.input_builder import (
StructurePredictionInput,
serialize_structure_prediction_input,
)
from esm.utils.structure.molecular_complex import (
MolecularComplex,
MolecularComplexResult,
)
from esm.utils.types import FunctionAnnotation
@@ -36,10 +53,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
return [FunctionAnnotation(*t) for t in l]
def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
ret = data.get("logits", {}).get(track, None)
# TODO(s22chan): just return this when removing return_bytes
return ret if ret is None or not return_bytes else maybe_tensor(ret)
def _maybe_logits(data: dict[str, Any], track: str):
return maybe_tensor(data.get("logits", {}).get(track, None))
def _maybe_b64_decode(obj, return_bytes: bool):
@@ -93,9 +108,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
)
@staticmethod
def _process_fold_request(sequence: str, model_name: str | None):
def _process_fold_request(
sequence: str,
model_name: str | None,
):
request: dict[str, Any] = {"sequence": sequence}
request["model"] = model_name
return request
@@ -130,6 +149,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
return request
async def _async_fetch_msa(self, sequence: str) -> MSA:
print("Fetching MSA ... this may take a few minutes")
# Accept both "|" and ":" as the chainbreak token.
@@ -137,7 +157,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
data = await self._async_post(
"msa", request={}, params={"sequence": sequence, "use_env": False}
)
return MSA(sequences=data["msa"])
return MSA.from_sequences(sequences=data["msa"])
def _fetch_msa(self, sequence: str) -> MSA:
print("Fetching MSA ... this may take a few minutes")
@@ -146,7 +166,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
data = self._post(
"msa", request={}, params={"sequence": sequence, "use_env": False}
)
return MSA(sequences=data["msa"])
return MSA.from_sequences(sequences=data["msa"])
@retry_decorator
async def async_fold(
@@ -168,11 +188,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
del potential_sequence_of_concern
request = self._process_fold_request(
sequence, model_name if model_name is not None else self.model
sequence,
model_name if model_name is not None else self.model,
)
try:
data = await self._async_post("fold", request)
data = await self._async_post(
"fold",
request,
)
except ESMProteinError as e:
return e
@@ -199,16 +223,98 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
del potential_sequence_of_concern
request = self._process_fold_request(
sequence, model_name if model_name is not None else self.model
sequence,
model_name if model_name is not None else self.model,
)
try:
data = self._post("fold", request)
data = self._post(
"fold",
request,
)
except ESMProteinError as e:
return e
return self._process_fold_response(data, sequence)
@retry_decorator
async def async_fold_all_atom(
self,
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
Args:
all_atom_input: StructurePredictionInput containing sequences for different molecule types
model_name: Override the client level model name if needed
"""
request = self._process_fold_all_atom_request(
all_atom_input,
model_name if model_name is not None else self.model,
)
try:
data = await self._async_post(
"fold_all_atom",
request,
)
except ESMProteinError as e:
return e
return self._process_fold_all_atom_response(data)
@retry_decorator
def fold_all_atom(
self,
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
Args:
all_atom_input: StructurePredictionInput containing sequences for different molecule types
model_name: Override the client level model name if needed
"""
request = self._process_fold_all_atom_request(
all_atom_input,
model_name if model_name is not None else self.model,
)
try:
data = self._post(
"fold_all_atom",
request,
)
except ESMProteinError as e:
return e
return self._process_fold_all_atom_response(data)
@staticmethod
def _process_fold_all_atom_request(
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
) -> dict[str, Any]:
request: dict[str, Any] = {
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
"model": model_name,
}
return request
@staticmethod
def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
complex_data = data.get("complex")
molecular_complex = MolecularComplex.from_state_dict(complex_data)
return MolecularComplexResult(
complex=molecular_complex,
plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True),
ptm=data.get("ptm", None),
distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True),
)
@retry_decorator
async def async_inverse_fold(
self,
@@ -280,6 +386,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
return ESMProtein(sequence=data["sequence"])
class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
def __init__(
self,
@@ -602,19 +709,15 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
return LogitsOutput(
logits=ForwardTrackData(
sequence=_maybe_logits(data, "sequence", return_bytes),
structure=_maybe_logits(data, "structure", return_bytes),
secondary_structure=_maybe_logits(
data, "secondary_structure", return_bytes
),
sasa=_maybe_logits(data, "sasa", return_bytes),
function=_maybe_logits(data, "function", return_bytes),
sequence=_maybe_logits(data, "sequence"),
structure=_maybe_logits(data, "structure"),
secondary_structure=_maybe_logits(data, "secondary_structure"),
sasa=_maybe_logits(data, "sasa"),
function=_maybe_logits(data, "function"),
),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
residue_annotation_logits=_maybe_logits(
data, "residue_annotation", return_bytes
),
residue_annotation_logits=_maybe_logits(data, "residue_annotation"),
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
@@ -965,6 +1068,7 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
"sequence": config.sequence,
"return_embeddings": config.return_embeddings,
"return_mean_embedding": config.return_mean_embedding,
"return_mean_hidden_states": config.return_mean_hidden_states,
"return_hidden_states": config.return_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
}
@@ -981,12 +1085,11 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
output = LogitsOutput(
logits=ForwardTrackData(
sequence=_maybe_logits(data, "sequence", return_bytes)
),
logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
return output
@@ -1109,3 +1212,5 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
raise NotImplementedError(
f"Can not get underlying remote model {self.model} from a Forge client."
)

View File

@@ -2,10 +2,9 @@ import inspect
from contextvars import ContextVar
from functools import wraps
import httpx
from tenacity import (
retry,
retry_if_exception_type,
retry_if_exception,
retry_if_result,
stop_after_attempt,
wait_incrementing,
@@ -30,8 +29,12 @@ def retry_if_specific_error(exception):
def log_retry_attempt(retry_state):
try:
outcome = retry_state.outcome.result()
except Exception:
outcome = retry_state.outcome.exception()
print(
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}"
)
@@ -41,13 +44,18 @@ def retry_decorator(func):
instance's retry settings.
"""
def return_last_value(retry_state):
"""Return the result of the last call attempt."""
return retry_state.outcome.result()
@wraps(func)
async def async_wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return await func(instance, *args, **kwargs)
retry_decorator = retry(
retry_error_callback=return_last_value,
retry=retry_if_result(retry_if_specific_error)
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
| retry_if_exception(retry_if_specific_error),
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),
@@ -62,8 +70,9 @@ def retry_decorator(func):
if skip_retries_var.get():
return func(instance, *args, **kwargs)
retry_decorator = retry(
retry_error_callback=return_last_value,
retry=retry_if_result(retry_if_specific_error)
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
| retry_if_exception(retry_if_specific_error),
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),

View File

@@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import Protocol
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer

View File

@@ -10,12 +10,24 @@ from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder
from esm.sdk.api import ESMProtein, ESMProteinTensor
from esm.tokenization import TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import api as api_constants
from esm.utils.constants import esm3 as C

View File

@@ -7,13 +7,26 @@ from esm.models.vqvae import StructureTokenEncoder
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import encode_function_annotations
from esm.utils.function.encode_decode import (
encode_function_annotations,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation
@@ -152,6 +165,8 @@ def tokenize_function_annotations(
return function_tokens, residue_annotation_tokens
# Tokenized Defaults
def get_default_sequence_tokens(
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
@@ -227,3 +242,5 @@ def get_default_residue_annotation_tokens(
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
return residue_annotation_tokens

View File

@@ -7,7 +7,10 @@ from typing import Any, Callable, List
from tqdm import tqdm
from esm.sdk.api import ESMProteinError
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
from esm.sdk.retry import (
retry_if_specific_error,
skip_retries_var,
)
TQDM_BAR_FORMAT = (
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "

View File

@@ -3,9 +3,16 @@ from typing import Sequence
import torch
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.models.function_decoder import (
FunctionTokenDecoder,
merge_annotations,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.types import FunctionAnnotation

View File

@@ -19,8 +19,13 @@ from esm.sdk.api import (
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization import (
EsmTokenizerBase,
TokenizerCollectionProtocol,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
@@ -43,7 +48,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
sliced = {}
for k, v in attr.asdict(o, recurse=False).items():
if v is None:
if k in ["mean_hidden_state", "mean_embedding"]:
sliced[k] = v
elif v is None:
sliced[k] = None
elif isinstance(v, torch.Tensor):
# Trim padding.

View File

@@ -1,7 +1,19 @@
from __future__ import annotations
import os
from collections import defaultdict
from dataclasses import is_dataclass
from io import BytesIO
from typing import Any, ContextManager, Sequence, TypeVar
from typing import (
Any,
ContextManager,
Generator,
Iterable,
Protocol,
Sequence,
TypeVar,
runtime_checkable,
)
from warnings import warn
import huggingface_hub
@@ -18,6 +30,12 @@ MAX_SUPPORTED_DISTANCE = 1e6
TSequence = TypeVar("TSequence", bound=Sequence)
@runtime_checkable
class Concatable(Protocol):
@classmethod
def concat(cls, objs: list[Concatable]) -> Concatable: ...
def slice_python_object_as_numpy(
obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
@@ -52,6 +70,37 @@ def slice_python_object_as_numpy(
return sliced_obj # type: ignore
def slice_any_object(
obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
"""
Slice a arbitrary object (like a list, string, or tuple) as if it was a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so.
If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correcty implemented numpy indexing.
Example:
>>> obj = "ABCDE"
>>> slice_any_object(obj, [1, 3, 4])
"BDE"
>>> obj = np.array([1, 2, 3, 4, 5])
>>> slice_any_object(obj, np.arange(5) < 3)
np.array([1, 2, 3])
>>> obj = ProteinChain.from_rcsb("1a3a", "A")
>>> slice_any_object(obj, np.arange(len(obj)) < 10)
# ProteinChain w/ length 10
"""
if isinstance(obj, (np.ndarray, torch.Tensor)):
return obj[idx] # type: ignore
elif is_dataclass(obj):
# if passing a dataclass, assume it implements a custom slice
return obj[idx] # type: ignore
else:
return slice_python_object_as_numpy(obj, idx)
def rbf(values, v_min, v_max, n_bins=16):
"""
Returns RBF encodings in a new dimension at the end.
@@ -298,6 +347,8 @@ def replace_inf(data):
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
if x is None:
return None
if isinstance(x, torch.Tensor):
return x
if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x):
return torch.stack(x)
if convert_none_to_nan:
@@ -357,3 +408,90 @@ def deserialize_tensors(b: bytes) -> Any:
buf = BytesIO(zstd.ZSTD_uncompress(b))
d = torch.load(buf, map_location="cpu", weights_only=False)
return d
def join_lists(
lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None
) -> list[Any]:
"""Joins multiple lists with separator element. Like str.join but for lists.
Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4]
Args:
lists: Lists of elements to chain
separator: separators to intsert between chained output.
Returns:
Joined lists.
"""
if not lists:
return []
joined = []
joined.extend(lists[0])
for l in lists[1:]:
if separator:
joined.extend(separator)
joined.extend(l)
return joined
def iterate_with_intermediate(
lists: Iterable, intermediate
) -> Generator[Any, None, None]:
"""
Iterate over the iterable, yielding the intermediate value between
every element of the intermediate. Useful for joining objects with
separator tokens.
"""
it = iter(lists)
yield next(it)
for l in it:
yield intermediate
yield l
def concat_objects(objs: Sequence[Any], separator: Any | None = None):
"""
Concat objects with each other using a separator token.
Supports:
- Concatable (objects that implement `concat` classmethod)
- strings
- lists
- numpy arrays
- torch Tensors
Example:
>>> foo = "abc"
>>> bar = "def"
>>> concat_objects([foo, bar], "|")
"abc|def"
"""
match objs[0]:
case Concatable():
return objs[0].__class__.concat(objs) # type: ignore
case str():
assert isinstance(
separator, str
), "Trying to join strings but separator is not a string"
return separator.join(objs)
case list():
if separator is not None:
return join_lists(objs, [separator])
else:
return join_lists(objs)
case np.ndarray():
if separator is not None:
return np.concatenate(
list(iterate_with_intermediate(objs, np.array([separator])))
)
else:
return np.concatenate(objs)
case torch.Tensor():
if separator is not None:
return torch.cat(
list(iterate_with_intermediate(objs, torch.tensor([separator])))
)
else:
return torch.cat(objs) # type: ignore
case _:
raise TypeError(type(objs[0]))

View File

@@ -0,0 +1,7 @@
from esm.utils.msa.msa import (
MSA,
FastMSA,
remove_insertions_from_sequence,
)
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]

View File

@@ -0,0 +1,79 @@
import tempfile
from pathlib import Path
import numpy as np
from scipy.spatial.distance import cdist
from esm.utils.system import run_subprocess_with_errorcheck
def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
"""Greedily select sequences that either maximize or minimize hamming distance.
Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
iteratively add sequences to the list with the maximum (minimum) average Hamming
distance to the existing set of sequences.
Args:
array (np.ndarray): Character array representing the sequences in the MSA
num_seqs (int): Number of sequences to select.
mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
you're doing it to prove a point for a paper.
Returns:
list[int]: List of indices to select from the array
"""
assert mode in ("max", "min")
depth = array.shape[0]
if depth <= num_seqs:
return list(range(depth))
array = array.view(np.uint8)
optfunc = np.argmax if mode == "max" else np.argmin
all_indices = np.arange(depth)
indices = [0]
pairwise_distances = np.zeros((0, depth))
for _ in range(num_seqs - 1):
dist = cdist(array[indices[-1:]], array, "hamming")
pairwise_distances = np.concatenate([pairwise_distances, dist])
shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
shifted_index = optfunc(shifted_distance)
index = np.delete(all_indices, indices)[shifted_index]
indices.append(index)
indices = sorted(indices)
return indices
def hhfilter(
sequences: list[str],
seqid: int = 90,
diff: int = 0,
cov: int = 0,
qid: int = 0,
qsc: float = -20.0,
binary: str = "hhfilter",
) -> list[int]:
with tempfile.TemporaryDirectory(dir="/dev/shm") as tempdirname:
tempdir = Path(tempdirname)
fasta_file = tempdir / "input.fasta"
fasta_file.write_text(
"\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
)
output_file = tempdir / "output.fasta"
command = " ".join(
[
f"{binary}",
f"-i {fasta_file}",
"-M a3m",
f"-o {output_file}",
f"-id {seqid}",
f"-diff {diff}",
f"-cov {cov}",
f"-qid {qid}",
f"-qsc {qsc}",
]
).split(" ")
run_subprocess_with_errorcheck(command, capture_output=True)
with output_file.open() as f:
indices = [int(line[1:].strip()) for line in f if line.startswith(">")]
return indices

507
esm/utils/msa/msa.py Normal file
View File

@@ -0,0 +1,507 @@
from __future__ import annotations
import dataclasses
import string
from dataclasses import dataclass
from functools import cached_property
from itertools import islice
from typing import Sequence
import numpy as np
from Bio import SeqIO
from scipy.spatial.distance import cdist
from esm.utils.misc import slice_any_object
from esm.utils.msa.filter_sequences import (
greedy_select_indices,
hhfilter,
)
from esm.utils.parsing import (
FastaEntry,
read_sequences,
write_sequences,
)
from esm.utils.sequential_dataclass import SequentialDataclass
from esm.utils.system import PathOrBuffer
REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase))
def remove_insertions_from_sequence(seq: str) -> str:
return seq.translate(REMOVE_LOWERCASE_TRANSLATION)
@dataclass(frozen=True)
class MSA(SequentialDataclass):
"""Object-oriented interface to an MSA.
Args:
sequences (list[str]): List of protein sequences
headers (list[str]): List of headers describing the sequences
"""
entries: list[FastaEntry]
@cached_property
def sequences(self) -> list[str]:
return [entry.sequence for entry in self.entries]
@cached_property
def headers(self) -> list[str]:
return [entry.header for entry in self.entries]
def __repr__(self):
return (
f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})"
)
def to_fast_msa(self) -> FastMSA:
return FastMSA(self.array, self.headers)
@classmethod
def from_a3m(
cls,
path: PathOrBuffer,
remove_insertions: bool = True,
max_sequences: int | None = None,
) -> MSA:
entries = []
for header, seq in islice(read_sequences(path), max_sequences):
if remove_insertions:
seq = remove_insertions_from_sequence(seq)
if entries:
assert (
len(seq) == len(entries[0].sequence)
), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
entries.append(FastaEntry(header, seq))
return cls(entries)
def to_a3m(self, path: PathOrBuffer) -> None:
write_sequences(self.entries, path)
@classmethod
def from_stockholm(
cls,
path: PathOrBuffer,
remove_insertions: bool = True,
max_sequences: int | None = None,
) -> MSA:
entries = []
for record in islice(SeqIO.parse(path, "stockholm"), max_sequences):
header = f"{record.id} {record.description}"
seq = str(record.seq)
if entries:
assert (
len(seq) == len(entries[0].sequence)
), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
entries.append(FastaEntry(header, seq))
msa = cls(entries)
if remove_insertions:
keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"]
msa = msa.select_positions(keep_inds)
return msa
def to_bytes(self) -> bytes:
version = 1
version_bytes = version.to_bytes(1, "little")
seqlen_bytes = self.seqlen.to_bytes(4, "little")
depth_bytes = self.depth.to_bytes(4, "little")
array_bytes = self.array.tobytes()
header_bytes = "\n".join(entry.header for entry in self.entries).encode()
all_bytes = (
version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes
)
return all_bytes
@classmethod
def from_bytes(cls, data: bytes) -> MSA:
version_bytes, seqlen_bytes, depth_bytes, data = (
data[:1],
data[1:5],
data[5:9],
data[9:],
)
version = int.from_bytes(version_bytes, "little")
if version != 1:
raise ValueError(f"Unsupported version: {version}")
seqlen = int.from_bytes(seqlen_bytes, "little")
depth = int.from_bytes(depth_bytes, "little")
array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
array = np.frombuffer(array_bytes, dtype="|S1")
array = array.reshape(depth, seqlen)
headers = header_bytes.decode().split("\n")
# Sometimes the separation is two newlines, which results in an empty header.
headers = [header for header in headers if header]
entries = [
FastaEntry(header, b"".join(row).decode())
for header, row in zip(headers, array)
]
return cls(entries)
# TODO(jmaccarl): set remove_insertions to True by default here to match other utils
@classmethod
def from_sequences(
cls, sequences: list[str], remove_insertions: bool = False
) -> MSA:
if remove_insertions:
entries = [
FastaEntry("", remove_insertions_from_sequence(seq))
for seq in sequences
]
else:
entries = [FastaEntry("", seq) for seq in sequences]
return cls(entries)
def to_sequence_bytes(self) -> bytes:
"""Stores ONLY SEQUENCES in array format as bytes. Header information will be lost."""
seqlen_bytes = self.seqlen.to_bytes(4, "little")
array_bytes = self.array.tobytes()
all_bytes = seqlen_bytes + array_bytes
return all_bytes
@classmethod
def from_sequence_bytes(cls, data: bytes) -> MSA:
seqlen_bytes, array_bytes = data[:4], data[4:]
seqlen = int.from_bytes(seqlen_bytes, "little")
array = np.frombuffer(array_bytes, dtype="|S1")
array = array.reshape(-1, seqlen)
entries = [FastaEntry("", b"".join(row).decode()) for row in array]
return cls(entries)
@property
def depth(self) -> int:
return len(self.entries)
@property
def seqlen(self) -> int:
return len(self.entries[0].sequence)
@cached_property
def array(self) -> np.ndarray:
return np.array([list(seq) for seq in self.sequences], dtype="|S1")
@property
def query(self) -> str:
return self.entries[0].sequence
def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA:
"""Subselect rows of the MSA."""
entries = [self.entries[idx] for idx in indices]
return dataclasses.replace(self, entries=entries)
def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA:
"""Subselect columns of the MSA."""
entries = [
FastaEntry(header, "".join(seq[idx] for idx in indices))
for header, seq in self.entries
]
return dataclasses.replace(self, entries=entries)
def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
if isinstance(indices, int):
indices = [indices]
entries = [
FastaEntry(header, slice_any_object(seq, indices))
for header, seq in self.entries
]
return dataclasses.replace(self, entries=entries)
def __len__(self):
return self.seqlen
def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA:
"""Greedily select sequences that either maximize or minimize hamming distance.
Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
iteratively add sequences to the list with the maximum (minimum) average Hamming
distance to the existing set of sequences.
Args:
num_seqs (int): Number of sequences to select.
mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
you're doing it to prove a point for a paper.
Returns:
MSA object w/ subselected sequences.
"""
assert mode in ("max", "min")
if self.depth <= num_seqs:
return self
indices = greedy_select_indices(self.array, num_seqs, mode)
return self.select_sequences(indices)
def hhfilter(
self,
seqid: int = 90,
diff: int = 0,
cov: int = 0,
qid: int = 0,
qsc: float = -20.0,
binary: str = "hhfilter",
) -> MSA:
"""Apply hhfilter to the sequences in the MSA and return a filtered MSA."""
indices = hhfilter(
self.sequences,
seqid=seqid,
diff=diff,
cov=cov,
qid=qid,
qsc=qsc,
binary=binary,
)
return self.select_sequences(indices)
def select_random_sequences(self, num_seqs: int) -> MSA:
"""Uses random sampling to subselect sequences from the MSA. Always
keeps the query sequence.
"""
if num_seqs >= self.depth:
return self
# Subselect random, always keeping the query sequence.
indices = np.sort(
np.append(
0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
)
)
msa = self.select_sequences(indices) # type: ignore
return msa
def select_diverse_sequences(self, num_seqs: int) -> MSA:
"""Applies hhfilter to select ~num_seqs sequences, then uses random sampling
to subselect if necessary.
"""
if num_seqs >= self.depth:
return self
msa = self.hhfilter(diff=num_seqs)
if num_seqs < msa.depth:
msa = msa.select_random_sequences(num_seqs)
return msa
def pad_to_depth(self, depth: int) -> MSA:
if depth < self.depth:
raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
elif depth == self.depth:
return self
num_to_add = depth - self.depth
extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)]
return dataclasses.replace(self, entries=self.entries + extra_entries)
@classmethod
def stack(
cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True
) -> MSA:
"""Stack a series of MSAs. Optionally remove the query from msas after the first."""
all_entries = []
for i, msa in enumerate(msas):
entries = msa.entries
if i > 0 and remove_query_from_later_msas:
entries = entries[1:]
all_entries.extend(entries)
return cls(entries=all_entries)
@cached_property
def seqid(self) -> np.ndarray:
array = self.array.view(np.uint8)
seqid = 1 - cdist(array[0][None], array, "hamming")
return seqid[0]
@classmethod
def concat(
cls,
msas: Sequence[MSA],
join_token: str | None = "|",
allow_depth_mismatch: bool = False,
) -> MSA:
"""Concatenate a series of MSAs horizontally, along the sequence dimension."""
if not msas:
raise ValueError("Cannot concatenate an empty list of MSAs")
msa_depths = [msa.depth for msa in msas]
if len(set(msa_depths)) != 1:
if not allow_depth_mismatch:
raise ValueError("Depth mismatch in concatenating MSAs")
else:
max_depth = max(msa_depths)
msas = [msa.pad_to_depth(max_depth) for msa in msas]
headers = [
"|".join([str(h) for h in headers])
for headers in zip(*(msa.headers for msa in msas))
]
if join_token is None:
join_token = ""
seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))]
entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)]
return cls(entries)
@dataclass(frozen=True)
class FastMSA(SequentialDataclass):
"""Object-oriented interface to an MSA stored as a numpy uint8 array."""
array: np.ndarray
headers: list[str] | None = None
def __post_init__(self):
if self.headers is not None:
assert (
len(self.headers) == self.depth
), "Number of headers must match depth."
@classmethod
def from_bytes(cls, data: bytes) -> FastMSA:
version_bytes, seqlen_bytes, depth_bytes, data = (
data[:1],
data[1:5],
data[5:9],
data[9:],
)
version = int.from_bytes(version_bytes, "little")
if version != 1:
raise ValueError(f"Unsupported version: {version}")
seqlen = int.from_bytes(seqlen_bytes, "little")
depth = int.from_bytes(depth_bytes, "little")
array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
array = np.frombuffer(array_bytes, dtype="|S1")
array = array.reshape(depth, seqlen)
headers = header_bytes.decode().split("\n")
# Sometimes the separation is two newlines, which results in an empty header.
headers = [header for header in headers if header]
return cls(array, headers)
@classmethod
def from_sequence_bytes(cls, data: bytes) -> FastMSA:
seqlen_bytes, array_bytes = data[:4], data[4:]
seqlen = int.from_bytes(seqlen_bytes, "little")
array = np.frombuffer(array_bytes, dtype="|S1")
array = array.reshape(-1, seqlen)
return cls(array)
@property
def depth(self) -> int:
return self.array.shape[0]
@property
def seqlen(self) -> int:
return self.array.shape[1]
def __len__(self):
return self.seqlen
def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
if isinstance(indices, int):
indices = [indices]
return dataclasses.replace(self, array=self.array[:, indices])
def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA:
"""Subselect rows of the MSA."""
array = self.array[indices]
headers = (
[self.headers[idx] for idx in indices] if self.headers is not None else None
)
return dataclasses.replace(self, array=array, headers=headers)
def select_random_sequences(self, num_seqs: int) -> FastMSA:
"""Uses random sampling to subselect sequences from the MSA. Always
keeps the query sequence.
"""
if num_seqs >= self.depth:
return self
# Subselect random, always keeping the query sequence.
indices = np.sort(
np.append(
0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
)
)
msa = self.select_sequences(indices) # type: ignore
return msa
def pad_to_depth(self, depth: int) -> FastMSA:
if depth < self.depth:
raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
elif depth == self.depth:
return self
num_to_add = depth - self.depth
array = np.pad(
self.array,
[(0, num_to_add), (0, 0)],
constant_values=ord("-") if self.array.dtype == np.uint8 else b"-",
)
headers = self.headers
if headers is not None:
headers = headers + [""] * num_to_add
return dataclasses.replace(self, array=array, headers=headers)
@classmethod
def concat(
cls,
msas: Sequence[FastMSA],
join_token: str | None = None,
allow_depth_mismatch: bool = False,
) -> FastMSA:
"""Concatenate a series of MSAs horizontally, along the sequence dimension."""
if not msas:
raise ValueError("Cannot concatenate an empty list of MSAs")
if join_token is not None and join_token != "":
raise NotImplementedError("join_token is not supported for FastMSA")
msa_depths = [msa.depth for msa in msas]
if len(set(msa_depths)) != 1:
if not allow_depth_mismatch:
raise ValueError("Depth mismatch in concatenating MSAs")
else:
max_depth = max(msa_depths)
msas = [msa.pad_to_depth(max_depth) for msa in msas]
headers = [
"|".join([str(h) for h in headers])
for headers in zip(
*(
msa.headers if msa.headers is not None else [""] * msa.depth
for msa in msas
)
)
]
array = np.concatenate([msa.array for msa in msas], axis=1)
return cls(array, headers)
def to_msa(self) -> MSA:
headers = (
self.headers
if self.headers is not None
else [f"seq{i}" for i in range(self.depth)]
)
entries = [
FastaEntry(header, b"".join(row).decode())
for header, row in zip(headers, self.array)
]
return MSA(entries)
@classmethod
def stack(
cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True
) -> FastMSA:
"""Stack a series of MSAs. Optionally remove the query from msas after the first."""
arrays = []
all_headers = []
for i, msa in enumerate(msas):
array = msa.array
headers = msa.headers
if i > 0 and remove_query_from_later_msas:
array = array[1:]
if headers is not None:
headers = headers[1:]
arrays.append(array)
if headers is not None:
all_headers.extend(headers)
return cls(np.concatenate(arrays, axis=0), all_headers)

83
esm/utils/parsing.py Normal file
View File

@@ -0,0 +1,83 @@
import io
from pathlib import Path
from typing import Generator, Iterable, NamedTuple
PathOrBuffer = str | Path | io.TextIOBase
FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)])
def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]:
"""
Parses a fasta file and yields FastaEntry objects
Args:
fasta_string: The fasta file as a string
Returns:
A generator of FastaEntry objects
"""
header = None
seq = []
num_sequences = 0
for line in fasta_string.splitlines():
if not line or line[0] == "#":
continue
if line.startswith(">"):
if header is not None:
yield FastaEntry(header, "".join(seq))
seq = []
header = line[1:].strip()
else:
seq.append(line)
if header is not None:
num_sequences += 1
yield FastaEntry(header, "".join(seq))
if num_sequences == 0:
raise ValueError("Found no sequences in input")
def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]:
# Uses duck typing to try and call the right method
# Doesn't use explicit isinstance check to support
# inputs that are not explicitly str/Path/TextIOBase but
# may support similar functionality
data = None # type: ignore
try:
if str(path).endswith(".gz"):
import gzip
data = gzip.open(path, "rt") # type: ignore
else:
try:
data = open(path) # type: ignore
except TypeError:
data: io.TextIOBase = path # type: ignore
yield from parse_fasta(data.read())
finally:
if data is not None:
data.close()
def read_first_sequence(path: PathOrBuffer) -> FastaEntry:
return next(iter(read_sequences(path)))
def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None:
needs_closing = False
handle = None
try:
try:
handle = open(path, "w") # type: ignore
needs_closing = True
except TypeError:
handle = path
has_prev = False
for header, seq in sequences:
if has_prev:
handle.write("\n") # type: ignore
handle.write(f">{header}\n{seq}") # type: ignore
has_prev = True
finally:
if needs_closing:
handle.close() # type: ignore

View File

@@ -5,9 +5,18 @@ import attr
import torch
import torch.nn.functional as F
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.sdk.api import (
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import (
TokenizerCollectionProtocol,
get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import (
MAX_RESIDUE_ANNOTATIONS,
SASA_DISCRETIZATION_BOUNDARIES,

View File

@@ -0,0 +1,157 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import TypeVar
import numpy as np
from esm.utils.misc import concat_objects, slice_any_object
T = TypeVar("T")
@dataclass(frozen=True)
class SequentialDataclass(ABC):
"""
This is a builder on a dataclass that allows for automatic slicing and concatenation.
When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein).
When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function).
We also have some fields that are not sequential (like an id, or data source), which we don't want to crop.
The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically.
This is done through the `metadata` field, which can take 3 values:
`sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False.
`sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
`join_token` (Any): What token to use to join when concatenating elements. Default: None.
Example:
@dataclass(frozen=True)
class Foo(SequentialDataclass):
id: str
sequence: str = field(metadata={"sequence": True, "join_token": "|"})
tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
def __len__(self):
# Must implement the __len__ method
return len(self.sequence)
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717]))
>>> foo[1:4]
Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251]))
>>> foo[np.arange(5) < 3]
Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
>>> Foo.concat([foo[:2], foo[3:]])
Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717]))
# Trying to create a type where the sequence lengths do not match raises an error
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
"""
def __post_init__(self):
self._check_sequence_lengths_match()
@abstractmethod
def __len__(self):
raise NotImplementedError
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
updated_fields = {}
if isinstance(idx, int):
# make it so that things remain sequential
idx = [idx]
for fld in fields(self):
if fld.metadata.get("sequence", False):
# this is a sequence, should be the same length as all other sequences
sequence_dim = fld.metadata.get("sequence_dim", 0)
value = getattr(self, fld.name)
if value is None:
continue
match sequence_dim:
case 0:
# sequence is first dimension
value = getattr(self, fld.name)
value = slice_any_object(value, idx)
updated_fields[fld.name] = value
case 1:
new_value = [slice_any_object(item, idx) for item in value]
updated_fields[fld.name] = value.__class__(new_value)
case _:
raise NotImplementedError(
"Arbitrary slicing for different sequence length fields is not implemented"
)
return replace(self, **updated_fields)
def _check_sequence_lengths_match(self):
"""Checks if sequence lengths of all "sequence" fields match."""
for fld in fields(self):
if fld.metadata.get("sequence", False) and fld.name != "complex":
# this is a sequence, should be the same length as all other sequences
sequence_dim = fld.metadata.get("sequence_dim", 0)
value = getattr(self, fld.name)
if value is None:
continue
match sequence_dim:
case 0:
# sequence is first dimension
value = getattr(self, fld.name)
if len(value) != len(self):
raise ValueError(
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}"
)
case 1:
for item in value:
if len(item) != len(self):
raise ValueError(
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
)
case _:
raise NotImplementedError(
"Arbitrary matching for different sequence length fields is not implemented"
)
@classmethod
def concat(cls, items: list[T], **kwargs) -> T:
updated_fields = {}
for fld in fields(cls):
if fld.metadata.get("sequence", False):
# this is a sequence, should be the same length as all other sequences
sequence_dim = fld.metadata.get("sequence_dim", 0)
join_value = fld.metadata.get("join_token", None)
if getattr(items[0], fld.name) is None:
continue
values = [getattr(item, fld.name) for item in items]
match sequence_dim:
case 0:
# sequence is first dimension
value = concat_objects(values, join_value)
updated_fields[fld.name] = value
case 1:
new_value = [
concat_objects(item, join_value) for item in zip(*values)
]
updated_fields[fld.name] = getattr(
items[0], fld.name
).__class__(new_value)
case _:
raise NotImplementedError(
"Arbitrary joining for different sequence length fields is not implemented"
)
updated_fields.update(kwargs)
return replace(
items[0], # type: ignore
**updated_fields,
)

View File

@@ -6,7 +6,9 @@ from typing import Any, ClassVar, Protocol, TypeVar
import numpy as np
import torch
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
from esm.utils.structure.protein_structure import (
compute_affine_and_rmsd,
)
class Alignable(Protocol):

View File

@@ -1,6 +1,8 @@
import numpy as np
from esm.utils.structure.protein_structure import index_by_atom_name
from esm.utils.structure.protein_structure import (
index_by_atom_name,
)
class AtomIndexer:

View File

@@ -0,0 +1,96 @@
from dataclasses import dataclass
from typing import Any, Sequence
import numpy as np
@dataclass
class Modification:
position: int # zero-indexed
ccd: str
@dataclass
class ProteinInput:
id: str | list[str]
sequence: str
modifications: list[Modification] | None = None
@dataclass
class RNAInput:
id: str | list[str]
sequence: str
modifications: list[Modification] | None = None
@dataclass
class DNAInput:
id: str | list[str]
sequence: str
modifications: list[Modification] | None = None
@dataclass
class LigandInput:
id: str | list[str]
smiles: str
ccd: list[str] | None = None
@dataclass
class DistogramConditioning:
chain_id: str
distogram: np.ndarray
@dataclass
class PocketConditioning:
binder_chain_id: str
contacts: list[tuple[str, int]]
@dataclass
class StructurePredictionInput:
sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
pocket: PocketConditioning | None = None
distogram_conditioning: list[DistogramConditioning] | None = None
def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
chain_data: dict[str, Any] = {
"sequence": seq_input.sequence,
"id": seq_input.id,
"type": chain_type,
}
if hasattr(seq_input, "modifications") and seq_input.modifications:
mods = [
{"position": mod.position, "ccd": mod.ccd}
for mod in seq_input.modifications
]
chain_data["modifications"] = mods
return chain_data
sequences = []
for seq_input in all_atom_input.sequences:
if isinstance(seq_input, ProteinInput):
sequences.append(create_chain_data(seq_input, "protein"))
elif isinstance(seq_input, RNAInput):
sequences.append(create_chain_data(seq_input, "rna"))
elif isinstance(seq_input, DNAInput):
sequences.append(create_chain_data(seq_input, "dna"))
elif isinstance(seq_input, LigandInput):
sequences.append(
{
"smiles": seq_input.smiles,
"id": seq_input.id,
"ccd": seq_input.ccd,
"type": "ligand",
}
)
else:
raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")
return {"sequences": sequences}

View File

@@ -264,7 +264,7 @@ def compute_lddt_ca(
if all_atom_pred_pos.dim() != 3:
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos]
return compute_lddt(
all_atom_pred_pos,

View File

@@ -0,0 +1,944 @@
from __future__ import annotations
import io
import os
import re
from dataclasses import asdict, dataclass
from pathlib import Path
from subprocess import check_output
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, List
import biotite.structure.io.pdbx as pdbx
import brotli
import msgpack
import numpy as np
import torch
from esm.utils import residue_constants
from esm.utils.structure.metrics import (
compute_lddt,
compute_rmsd,
)
from esm.utils.structure.protein_complex import (
ProteinComplex,
ProteinComplexMetadata,
)
@dataclass
class MolecularComplexResult:
"""Result of molecular complex folding"""
complex: MolecularComplex
plddt: torch.Tensor | None = None
ptm: float | None = None
iptm: float | None = None
pae: torch.Tensor | None = None
distogram: torch.Tensor | None = None
pair_chains_iptm: torch.Tensor | None = None
output_embedding_sequence: torch.Tensor | None = None
output_embedding_pair_pooled: torch.Tensor | None = None
@dataclass
class MolecularComplexMetadata:
"""Metadata for MolecularComplex objects."""
entity_lookup: dict[int, str]
chain_lookup: dict[int, str]
assembly_composition: dict[str, list[str]] | None = None
@dataclass
class Molecule:
"""Represents a single molecule/token within a MolecularComplex."""
token: str
token_idx: int
atom_positions: np.ndarray # [N_atoms, 3]
atom_elements: np.ndarray # [N_atoms] element strings
residue_type: int
molecule_type: int # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
confidence: float
@dataclass(frozen=True)
class MolecularComplex:
"""
Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands.
Uses a flat atom representation with token-based sequence indexing, supporting all atom types
beyond the traditional atom37 protein representation.
"""
id: str
sequence: List[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP']
# Flat atom arrays - simplified representation
atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates
atom_elements: np.ndarray # [N_atoms] element strings
# Token-to-atom mapping for efficient access
token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
# Confidence data
plddt: np.ndarray # Per-token confidence scores [N_tokens]
# Metadata
metadata: MolecularComplexMetadata
def __post_init__(self):
"""Validate array dimensions."""
n_tokens = len(self.sequence)
assert (
self.token_to_atoms.shape[0] == n_tokens
), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
assert (
self.plddt.shape[0] == n_tokens
), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
def __len__(self) -> int:
"""Return number of tokens."""
return len(self.sequence)
def __getitem__(self, idx: int) -> Molecule:
"""Access individual molecules/tokens by index."""
if idx >= len(self.sequence) or idx < 0:
raise IndexError(
f"Token index {idx} out of range for {len(self.sequence)} tokens"
)
token = self.sequence[idx]
start_atom, end_atom = self.token_to_atoms[idx]
# Extract atom data for this token
token_atom_positions = self.atom_positions[start_atom:end_atom]
token_atom_elements = self.atom_elements[start_atom:end_atom]
# Default values for residue/molecule type (would be extended based on actual implementation)
residue_type = 0 # Default to standard residue
molecule_type = 0 # Default to protein
return Molecule(
token=token,
token_idx=idx,
atom_positions=token_atom_positions,
atom_elements=token_atom_elements,
residue_type=residue_type,
molecule_type=molecule_type,
confidence=self.plddt[idx],
)
@property
def atom_coordinates(self) -> np.ndarray:
"""Get flat array of all atom coordinates [N_atoms, 3]."""
return self.atom_positions
# Conversion methods
@classmethod
def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex":
"""Convert a ProteinComplex to MolecularComplex.
Args:
pc: ProteinComplex object with atom37 representation
Returns:
MolecularComplex with flat atom arrays and token-based indexing
"""
from esm.utils import residue_constants
# Extract sequence without chain breaks
sequence_no_breaks = pc.sequence.replace("|", "")
sequence_tokens = [
residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks
]
# Convert atom37 to flat arrays
flat_positions = []
flat_elements = []
token_to_atoms = []
atom_idx = 0
residue_idx = 0
for i, aa in enumerate(pc.sequence):
if aa == "|":
# Skip chain break tokens
continue
# Get atom37 positions and mask for this residue
res_positions = pc.atom37_positions[residue_idx] # [37, 3]
res_mask = pc.atom37_mask[residue_idx] # [37]
# Track start position for this token
token_start = atom_idx
# Process each atom type in atom37 representation
for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
if res_mask[atom_type_idx]: # Atom is present
# Add position
flat_positions.append(res_positions[atom_type_idx])
# Determine element from atom name
element = (
atom_name[0] if atom_name else "C"
) # First character is element
flat_elements.append(element)
atom_idx += 1
# Record token-to-atom mapping [start_idx, end_idx)
token_to_atoms.append([token_start, atom_idx])
residue_idx += 1
# Convert to numpy arrays
atom_positions = np.array(flat_positions, dtype=np.float32)
atom_elements = np.array(flat_elements, dtype=object)
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
# Extract confidence scores (skip chain breaks)
confidence_scores = []
residue_idx = 0
for aa in pc.sequence:
if aa != "|":
confidence_scores.append(pc.confidence[residue_idx])
residue_idx += 1
confidence_array = np.array(confidence_scores, dtype=np.float32)
# Create metadata - convert entity IDs to strings for MolecularComplexMetadata
entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
metadata = MolecularComplexMetadata(
entity_lookup=entity_lookup_str,
chain_lookup=pc.metadata.chain_lookup,
assembly_composition=pc.metadata.assembly_composition,
)
return cls(
id=pc.id,
sequence=sequence_tokens,
atom_positions=atom_positions,
atom_elements=atom_elements,
token_to_atoms=token_to_atoms_array,
plddt=confidence_array,
metadata=metadata,
)
def to_protein_complex(self) -> ProteinComplex:
"""Convert MolecularComplex back to ProteinComplex format.
Extracts only protein tokens and converts from flat atom representation
back to atom37 format used by ProteinComplex.
Returns:
ProteinComplex with protein residues only, excluding ligands/nucleic acids
"""
from esm.utils import residue_constants
# No need for element mapping - already using element characters
# Filter for protein tokens only (skip ligands, nucleic acids)
protein_tokens = []
protein_indices = []
for i, token in enumerate(self.sequence):
# Check if token is a standard 3-letter amino acid code
if token in residue_constants.restype_3to1:
protein_tokens.append(token)
protein_indices.append(i)
if not protein_tokens:
raise ValueError("No protein tokens found in MolecularComplex")
n_residues = len(protein_tokens)
# Initialize atom37 arrays
atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
atom37_mask = np.zeros((n_residues, 37), dtype=bool)
# Convert tokens back to single-letter sequence
single_letter_sequence = "".join(
[residue_constants.restype_3to1[token] for token in protein_tokens]
)
# Extract confidence scores for protein residues only
protein_confidence = self.plddt[protein_indices]
# Convert flat atoms back to atom37 representation
for res_idx, token_idx in enumerate(protein_indices):
token = self.sequence[token_idx]
start_atom, end_atom = self.token_to_atoms[token_idx]
# Get atom data for this residue
res_atom_positions = self.atom_positions[start_atom:end_atom]
# Reconstruct atom37 representation by exactly reversing the forward conversion logic
# In from_protein_complex, atoms are added in atom_types order if present in mask
# So we need to reconstruct the mask and positions in the same order
atom_count = 0
for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
# Check if this atom type exists for this residue and was present
residue_atoms = residue_constants.residue_atoms.get(token, [])
if atom_name in residue_atoms:
# This atom type exists for this residue, so it should have been included
if atom_count < len(res_atom_positions):
atom37_positions[res_idx, atom_type_idx] = res_atom_positions[
atom_count
]
atom37_mask[res_idx, atom_type_idx] = True
atom_count += 1
# Create other required arrays for ProteinComplex
# For simplicity, assume all protein residues belong to the same entity/chain
entity_id = np.zeros(n_residues, dtype=np.int64)
chain_id = np.zeros(n_residues, dtype=np.int64)
sym_id = np.zeros(n_residues, dtype=np.int64)
residue_index = np.arange(1, n_residues + 1, dtype=np.int64)
insertion_code = np.array([""] * n_residues, dtype=object)
# Create simplified protein complex metadata
# Map the first entity/chain from molecular complex metadata
protein_metadata = ProteinComplexMetadata(
entity_lookup={0: 1}, # Single entity (int for ProteinComplexMetadata)
chain_lookup={0: "A"}, # Single chain
assembly_composition=self.metadata.assembly_composition,
)
return ProteinComplex(
id=self.id,
sequence=single_letter_sequence,
entity_id=entity_id,
chain_id=chain_id,
sym_id=sym_id,
residue_index=residue_index,
insertion_code=insertion_code,
atom37_positions=atom37_positions,
atom37_mask=atom37_mask,
confidence=protein_confidence,
metadata=protein_metadata,
)
@classmethod
def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex":
"""Read MolecularComplex from mmcif file or string.
Args:
inp: Path to mmCIF file or mmCIF content as string
id: Optional identifier to assign to the complex
Returns:
MolecularComplex with all molecules (proteins, ligands, nucleic acids)
"""
from io import StringIO
# Check if input is a file path or mmCIF string content
if os.path.exists(inp):
# Input is a file path
mmcif_file = pdbx.CIFFile.read(inp)
else:
# Input is mmCIF string content
mmcif_file = pdbx.CIFFile.read(StringIO(inp))
# Get structure - handle missing model information gracefully
try:
structure = pdbx.get_structure(mmcif_file, model=1)
except (KeyError, ValueError):
# Fallback for mmCIF files without model information
try:
structure = pdbx.get_structure(mmcif_file)
except Exception:
# Last resort: use the first available model or all atoms
structure = pdbx.get_structure(mmcif_file, model=None)
# Type hint for pyright - structure is an AtomArray which is iterable
if TYPE_CHECKING:
structure: Any = structure
# Get entity information from mmCIF
entity_info = {}
try:
# Access the first block in CIFFile
block = mmcif_file[0]
if "entity" in block:
entity_category = block["entity"]
if "id" in entity_category and "type" in entity_category:
entity_ids = entity_category["id"]
entity_types = entity_category["type"]
# Convert CIFColumn to list for iteration
if hasattr(entity_ids, "__iter__") and hasattr(
entity_types, "__iter__"
):
# Type annotation to help pyright understand these are iterable
entity_ids_list = list(entity_ids) # type: ignore
entity_types_list = list(entity_types) # type: ignore
for eid, etype in zip(entity_ids_list, entity_types_list):
entity_info[eid] = etype
except Exception:
pass
# Initialize arrays for flat atom representation
sequence_tokens = []
flat_positions = []
flat_elements = []
token_to_atoms = []
confidence_scores = []
atom_idx = 0
# Group atoms by chain and residue
chain_residue_groups = {}
for atom in structure:
chain_id = atom.chain_id
res_id = atom.res_id
res_name = atom.res_name
if chain_id not in chain_residue_groups:
chain_residue_groups[chain_id] = {}
if res_id not in chain_residue_groups[chain_id]:
chain_residue_groups[chain_id][res_id] = {
"atoms": [],
"res_name": res_name,
"is_hetero": atom.hetero,
}
chain_residue_groups[chain_id][res_id]["atoms"].append(atom)
# Process each chain and residue
for chain_id in sorted(chain_residue_groups.keys()):
residues = chain_residue_groups[chain_id]
for res_id in sorted(residues.keys()):
residue_data = residues[res_id]
res_name = residue_data["res_name"]
atoms = residue_data["atoms"]
is_hetero = residue_data["is_hetero"]
# Skip water molecules
if res_name == "HOH":
continue
# Determine token name
if not is_hetero and res_name in residue_constants.restype_3to1:
# Standard amino acid
token_name = res_name
elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]:
# Nucleotide
token_name = res_name
else:
# Ligand or other molecule
token_name = res_name
sequence_tokens.append(token_name)
token_start = atom_idx
# Add all atoms from this residue
for atom in atoms:
flat_positions.append(atom.coord)
# Get element character
element = atom.element
flat_elements.append(element)
atom_idx += 1
# Record token-to-atom mapping
token_to_atoms.append([token_start, atom_idx])
# Add confidence score (B-factor if available, otherwise 1.0)
bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0
confidence_scores.append(min(bfactor / 100.0, 1.0))
# Convert to numpy arrays
if not flat_positions:
# Create minimal arrays if no atoms found
atom_positions = np.zeros((0, 3), dtype=np.float32)
atom_elements = np.zeros(0, dtype=object)
token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
else:
atom_positions = np.array(flat_positions, dtype=np.float32)
atom_elements = np.array(flat_elements, dtype=object)
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
confidence_array = np.array(confidence_scores, dtype=np.float32)
# Create metadata
metadata = MolecularComplexMetadata(
entity_lookup=entity_info,
chain_lookup={
i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys())
},
assembly_composition=None,
)
# Set complex ID - if input was a path, use the stem; otherwise use default
if os.path.exists(inp):
complex_id = id or Path(inp).stem
else:
complex_id = id or "complex_from_string"
return cls(
id=complex_id,
sequence=sequence_tokens,
atom_positions=atom_positions,
atom_elements=atom_elements,
token_to_atoms=token_to_atoms_array,
plddt=confidence_array,
metadata=metadata,
)
def to_mmcif(self) -> str:
"""Write MolecularComplex to mmcif string.
Returns:
String representation of the complex in mmCIF format
"""
# No need for element mapping - already using element characters
lines = []
# Header
lines.append(f"data_{self.id}")
lines.append("#")
lines.append(f"_entry.id {self.id}")
lines.append("#")
# Structure metadata
lines.append("_struct.entry_id {}".format(self.id))
lines.append("_struct.title 'Protein Structure'")
lines.append("#")
# Entity information
entity_id = 1
chain_counter = 0
lines.append("loop_")
lines.append("_entity.id")
lines.append("_entity.type")
lines.append("_entity.pdbx_description")
# Determine entities based on sequence
protein_tokens = []
other_tokens = []
for i, token in enumerate(self.sequence):
if token in residue_constants.restype_3to1:
protein_tokens.append((i, token))
else:
other_tokens.append((i, token))
if protein_tokens:
lines.append(f"{entity_id} polymer 'Protein chain'")
entity_id += 1
for token in set(token for _, token in other_tokens):
lines.append(f"{entity_id} non-polymer 'Ligand {token}'")
entity_id += 1
lines.append("#")
# Chain assignments
lines.append("loop_")
lines.append("_struct_asym.id")
lines.append("_struct_asym.entity_id")
chain_id = "A"
if protein_tokens:
lines.append(f"{chain_id} 1")
chain_counter += 1
chain_id = chr(ord(chain_id) + 1)
entity_id = 2
for token in set(token for _, token in other_tokens):
lines.append(f"{chain_id} {entity_id}")
entity_id += 1
chain_counter += 1
if chain_counter < 26:
chain_id = chr(ord(chain_id) + 1)
lines.append("#")
# Atom site information
lines.append("loop_")
lines.append("_atom_site.group_PDB")
lines.append("_atom_site.id")
lines.append("_atom_site.type_symbol")
lines.append("_atom_site.label_atom_id")
lines.append("_atom_site.label_alt_id")
lines.append("_atom_site.label_comp_id")
lines.append("_atom_site.label_asym_id")
lines.append("_atom_site.label_entity_id")
lines.append("_atom_site.label_seq_id")
lines.append("_atom_site.pdbx_PDB_ins_code")
lines.append("_atom_site.Cartn_x")
lines.append("_atom_site.Cartn_y")
lines.append("_atom_site.Cartn_z")
lines.append("_atom_site.occupancy")
lines.append("_atom_site.B_iso_or_equiv")
lines.append("_atom_site.pdbx_PDB_model_num")
lines.append("_atom_site.auth_seq_id")
lines.append("_atom_site.auth_comp_id")
lines.append("_atom_site.auth_asym_id")
lines.append("_atom_site.auth_atom_id")
atom_id = 1
seq_id = 1
chain_id = "A"
entity_id = 1
for token_idx, token in enumerate(self.sequence):
start_atom, end_atom = self.token_to_atoms[token_idx]
# Determine if this is a protein residue or ligand
is_protein = token in residue_constants.restype_3to1
group_pdb = "ATOM" if is_protein else "HETATM"
current_entity_id = 1 if is_protein else 2 # Simplified entity assignment
current_chain_id = "A" if is_protein else "B" # Simplified chain assignment
# Create atom names for this token
atom_names = []
if is_protein:
# Use standard protein atom names
res_atoms = residue_constants.residue_atoms.get(
token, ["N", "CA", "C", "O"]
)
atom_names = res_atoms[: end_atom - start_atom]
else:
# Generate generic atom names for ligands
for i in range(end_atom - start_atom):
atom_names.append(f"C{i+1}")
# Pad atom names if needed
while len(atom_names) < (end_atom - start_atom):
atom_names.append(f"X{len(atom_names)+1}")
# Write atoms for this token
for atom_idx_in_token, global_atom_idx in enumerate(
range(start_atom, end_atom)
):
pos = self.atom_positions[global_atom_idx]
element_char = self.atom_elements[global_atom_idx]
element_symbol = element_char if isinstance(element_char, str) else "C"
atom_name = (
atom_names[atom_idx_in_token]
if atom_idx_in_token < len(atom_names)
else f"X{atom_idx_in_token+1}"
)
# Format atom site line
bfactor = (
self.plddt[token_idx] * 100.0
if len(self.plddt) > token_idx
else 50.0
)
line = (
f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . "
f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? "
f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 "
f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}"
)
lines.append(line)
atom_id += 1
seq_id += 1
lines.append("#")
return "\n".join(lines)
def dockq(self, native: "MolecularComplex") -> Any:
"""Compute DockQ score against native structure.
Args:
native: Native MolecularComplex to compute DockQ against
Returns:
DockQ result containing score and alignment information
"""
# Imports moved to top of file
# Convert both complexes to ProteinComplex format for DockQ computation
# This extracts only the protein portion and converts to PDB format
try:
self_pc = self.to_protein_complex()
native_pc = native.to_protein_complex()
except ValueError as e:
raise ValueError(
f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
)
# Normalize chain IDs for PDB compatibility
self_pc = self_pc.normalize_chain_ids_for_pdb()
native_pc = native_pc.normalize_chain_ids_for_pdb()
# Use the existing ProteinComplex.dockq() method
try:
dockq_result = self_pc.dockq(native_pc)
return dockq_result
except Exception:
# Fallback to manual DockQ computation if ProteinComplex.dockq() fails
return self._compute_dockq_manual(native)
def _compute_dockq_manual(self, native: "MolecularComplex") -> Any:
"""Manual DockQ computation fallback."""
# Imports moved to top of file
# Convert both complexes to ProteinComplex format
try:
self_pc = self.to_protein_complex()
native_pc = native.to_protein_complex()
except ValueError as e:
raise ValueError(
f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
)
# Normalize chain IDs for PDB compatibility
self_pc = self_pc.normalize_chain_ids_for_pdb()
native_pc = native_pc.normalize_chain_ids_for_pdb()
# Write temporary PDB files and run DockQ
with TemporaryDirectory() as tdir:
dir_path = Path(tdir)
self_pdb = dir_path / "self.pdb"
native_pdb = dir_path / "native.pdb"
# Write PDB files
self_pc.to_pdb(self_pdb)
native_pc.to_pdb(native_pdb)
# Run DockQ
try:
output = check_output(["DockQ", str(self_pdb), str(native_pdb)])
output_text = output.decode()
# Parse DockQ output
lines = output_text.split("\n")
# Find the total DockQ score
dockq_score = None
for line in lines:
if "Total DockQ" in line:
match = re.search(r"Total DockQ.*: ([\d.]+)", line)
if match:
dockq_score = float(match.group(1))
break
if dockq_score is None:
# Try to find individual DockQ scores
for line in lines:
if line.startswith("DockQ") and ":" in line:
try:
dockq_score = float(line.split(":")[1].strip())
break
except (ValueError, IndexError):
continue
if dockq_score is None:
raise ValueError("Could not parse DockQ score from output")
# Return a simple result structure
return {
"total_dockq": dockq_score,
"raw_output": output_text,
"aligned": self, # Return self as aligned structure
}
except FileNotFoundError:
raise RuntimeError(
"DockQ is not installed. Please install DockQ to use this method."
)
except Exception as e:
raise RuntimeError(f"DockQ computation failed: {e}")
def rmsd(self, target: "MolecularComplex", **kwargs) -> float:
"""Compute RMSD against target structure.
Args:
target: Target MolecularComplex to compute RMSD against
**kwargs: Additional arguments passed to compute_rmsd
Returns:
float: RMSD value between the two structures
"""
# Imports moved to top of file
# Ensure both complexes have the same number of tokens
if len(self) != len(target):
raise ValueError(
f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
)
# Extract center positions for each token (using centroid of atoms)
mobile_coords = []
target_coords = []
atom_mask = []
for i in range(len(self)):
# Get atom positions for this token
mobile_start, mobile_end = self.token_to_atoms[i]
target_start, target_end = target.token_to_atoms[i]
# Extract atom positions
mobile_atoms = self.atom_positions[mobile_start:mobile_end]
target_atoms = target.atom_positions[target_start:target_end]
# Check if both tokens have atoms
if len(mobile_atoms) == 0 or len(target_atoms) == 0:
# Skip tokens with no atoms
continue
# For simplicity, use the centroid of atoms as the representative position
mobile_center = mobile_atoms.mean(axis=0)
target_center = target_atoms.mean(axis=0)
mobile_coords.append(mobile_center)
target_coords.append(target_center)
atom_mask.append(True)
if len(mobile_coords) == 0:
raise ValueError("No valid atoms found for RMSD computation")
# Convert to tensors
mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
0
) # [1, N, 3]
target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
0
) # [1, N, 3]
mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N]
# Compute RMSD using existing infrastructure
rmsd_value = compute_rmsd(
mobile=mobile_tensor,
target=target_tensor,
atom_exists_mask=mask_tensor,
reduction="batch",
**kwargs,
)
return float(rmsd_value)
def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float:
"""Compute LDDT score against target structure.
Args:
target: Target MolecularComplex to compute LDDT against
**kwargs: Additional arguments passed to compute_lddt
Returns:
float: LDDT value between the two structures
"""
# Imports moved to top of file
# Ensure both complexes have the same number of tokens
if len(self) != len(target):
raise ValueError(
f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
)
# Extract center positions for each token (using centroid of atoms)
mobile_coords = []
target_coords = []
atom_mask = []
for i in range(len(self)):
# Get atom positions for this token
mobile_start, mobile_end = self.token_to_atoms[i]
target_start, target_end = target.token_to_atoms[i]
# Extract atom positions
mobile_atoms = self.atom_positions[mobile_start:mobile_end]
target_atoms = target.atom_positions[target_start:target_end]
# Check if both tokens have atoms
if len(mobile_atoms) == 0 or len(target_atoms) == 0:
# Skip tokens with no atoms
mobile_coords.append(np.full(3, np.nan))
target_coords.append(np.full(3, np.nan))
atom_mask.append(False)
continue
# For simplicity, use the centroid of atoms as the representative position
mobile_center = mobile_atoms.mean(axis=0)
target_center = target_atoms.mean(axis=0)
mobile_coords.append(mobile_center)
target_coords.append(target_center)
atom_mask.append(True)
if not any(atom_mask):
raise ValueError("No valid atoms found for LDDT computation")
# Convert to tensors
mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
0
) # [1, N, 3]
target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
0
) # [1, N, 3]
mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N]
# Compute LDDT using existing infrastructure
lddt_value = compute_lddt(
all_atom_pred_pos=mobile_tensor,
all_atom_positions=target_tensor,
all_atom_mask=mask_tensor,
per_residue=False, # Return overall LDDT score
**kwargs,
)
return float(lddt_value)
def state_dict(self):
"""This state dict is optimized for storage, so it turns things to fp16 whenever
possible and converts numpy arrays to lists for JSON serialization.
"""
dct = {k: v for k, v in vars(self).items()}
for k, v in dct.items():
if isinstance(v, np.ndarray):
match v.dtype:
case np.int64:
dct[k] = v.astype(np.int32).tolist()
case np.float64 | np.float32:
dct[k] = v.astype(np.float16).tolist()
case _:
dct[k] = v.tolist()
elif isinstance(v, MolecularComplexMetadata):
dct[k] = asdict(v)
return dct
def to_blob(self) -> bytes:
return brotli.compress(msgpack.dumps(self.state_dict()), quality=5)
@classmethod
def from_state_dict(cls, dct):
for k, v in dct.items():
if isinstance(v, list) and k in [
"atom_positions",
"atom_elements",
"token_to_atoms",
"plddt",
]:
dct[k] = np.array(v)
for k, v in dct.items():
if isinstance(v, np.ndarray):
if k in ["atom_positions", "plddt"]:
dct[k] = v.astype(np.float32)
elif k in ["token_to_atoms"]:
dct[k] = v.astype(np.int32)
dct["metadata"] = MolecularComplexMetadata(**dct["metadata"])
return cls(**dct)
@classmethod
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
match input:
case Path() | str():
bytes = Path(input).read_bytes()
case io.BytesIO():
bytes = input.getvalue()
case _:
bytes = input
return cls.from_state_dict(
msgpack.loads(brotli.decompress(bytes), strict_map_key=False)
)

View File

@@ -25,13 +25,21 @@ from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.atom_indexer import AtomIndexer
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
from esm.utils.structure.metrics import (
compute_gdt_ts,
compute_lddt_ca,
)
from esm.utils.structure.mmcif_parsing import (
MmcifWrapper,
Residue,
)
from esm.utils.structure.normalize_coordinates import (
apply_frame_to_coords,
get_protein_normalization_frame,
)
from esm.utils.structure.protein_structure import index_by_atom_name
from esm.utils.structure.protein_structure import (
index_by_atom_name,
)
from esm.utils.types import PathOrBuffer
msgpack_numpy.patch()
@@ -393,6 +401,7 @@ class ProteinChain:
bytes = input
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
def sasa(self, by_residue: bool = True):
arr = self.atom_array_no_insertions
sasa_per_atom = bs.sasa(arr) # type: ignore
@@ -698,6 +707,7 @@ class ProteinChain:
)
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
@classmethod
def chain_iterable_from_mmcif(
cls,

View File

@@ -32,8 +32,14 @@ from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.atom_indexer import AtomIndexer
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
from esm.utils.structure.metrics import (
compute_gdt_ts,
compute_lddt_ca,
)
from esm.utils.structure.mmcif_parsing import (
MmcifWrapper,
NoProteinError,
)
from esm.utils.structure.protein_chain import (
ProteinChain,
chain_to_ndarray,

45
esm/utils/system.py Normal file
View File

@@ -0,0 +1,45 @@
import io
import subprocess
import typing as T
from pathlib import Path
PathLike = T.Union[str, Path]
PathOrBuffer = T.Union[PathLike, io.StringIO]
def run_subprocess_with_errorcheck(
*popenargs,
capture_output: bool = False,
quiet: bool = False,
env: dict[str, str] | None = None,
shell: bool = False,
executable: str | None = None,
**kws,
) -> subprocess.CompletedProcess:
"""A command similar to subprocess.run, however the errormessage will
contain the stderr when using this function. This makes it significantly
easier to diagnose issues.
"""
try:
if capture_output:
stdout = subprocess.PIPE
elif quiet:
stdout = subprocess.DEVNULL
else:
stdout = None
p = subprocess.run(
*popenargs,
stderr=subprocess.PIPE,
stdout=stdout,
check=True,
env=env,
shell=shell,
executable=executable,
**kws,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}"
)
return p

View File

@@ -4,7 +4,9 @@ import pygtrie
from ipywidgets import widgets
from esm.sdk.api import FunctionAnnotation
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
TRIE: pygtrie.CharTrie | None = None

View File

@@ -7,11 +7,15 @@ import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from esm.sdk.api import ESMProtein
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.drawing.draw_function_annotations import (
draw_function_annotations,
)
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
from esm.widgets.utils.drawing.draw_protein_structure import (
draw_protein_structure,
)
from esm.widgets.utils.serialization import (
create_download_button_from_buffer,
protein_to_pdb_buffer,

View File

@@ -3,9 +3,16 @@ from typing import Any, Callable, Sequence
import ipywidgets as widgets
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_hex,
)
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.prompting import PromptManager

View File

@@ -4,9 +4,16 @@ import ipywidgets as widgets
import pydssp
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_hex,
)
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.prompting import PromptManager

View File

@@ -6,7 +6,9 @@ from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_rgba_html_string,
)
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.prompting import PromptManager

View File

@@ -10,8 +10,12 @@ from matplotlib.patches import Rectangle
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils import indexing
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.drawing.draw_protein_structure import (
draw_protein_structure,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.prompting import PromptManager

View File

@@ -9,7 +9,10 @@ from matplotlib import colormaps
from PIL import Image
from esm.sdk.api import FunctionAnnotation
from esm.utils.function.interpro import InterPro, InterProEntryType
from esm.utils.function.interpro import (
InterPro,
InterProEntryType,
)
@contextmanager

View File

@@ -9,7 +9,9 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
from esm.utils import encoding
from esm.widgets.utils import indexing
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.printing import wrapped_print

View File

@@ -13,9 +13,13 @@ from esm.sdk.api import (
GenerationConfig,
)
from esm.utils.constants import models
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.serialization import create_download_results_button
from esm.widgets.utils.serialization import (
create_download_results_button,
)
def create_esm3_generation_launcher(

View File

@@ -1,6 +1,8 @@
from ipywidgets import widgets
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
from esm.widgets.components.sasa_prompt_selector import (
create_sasa_prompt_selector,
)
from esm.widgets.components.secondary_structure_prompt_selector import (
create_secondary_structure_prompt_selector,
)

View File

@@ -4,12 +4,20 @@ from ipywidgets import widgets
from esm.sdk.api import ESM3InferenceClient, ESMProtein
from esm.utils.constants import esm3 as C
from esm.widgets.components.function_annotator import create_function_annotator
from esm.widgets.components.function_annotator import (
create_function_annotator,
)
from esm.widgets.utils.prompting import PromptManagerCollection
from esm.widgets.utils.protein_import import ProteinImporter
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
from esm.widgets.views.esm3_generation_launcher import (
create_esm3_generation_launcher,
)
from esm.widgets.views.esm3_prompt_preview import (
create_esm3_prompt_preview,
)
from esm.widgets.views.esm3_prompt_selector import (
create_esm3_prompt_selector,
)
def create_generation_ui(

View File

@@ -6,7 +6,9 @@ from esm.sdk.api import (
ESMProteinError,
GenerationConfig,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import ProteinImporter

View File

@@ -4,7 +4,10 @@ from textwrap import dedent
from ipywidgets import widgets
from esm.widgets.utils.clients import get_forge_client, get_local_client
from esm.widgets.utils.clients import (
get_forge_client,
get_local_client,
)
from esm.widgets.utils.types import ClientInitContainer

View File

@@ -6,7 +6,9 @@ from esm.sdk.api import (
ESMProteinError,
GenerationConfig,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import ProteinImporter

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.2.2"
version = "3.2.2.post1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
@@ -45,7 +45,6 @@ dependencies = [
"pygtrie",
"dna_features_viewer",
]
# Pytest
[tool.pytest.ini_options]
addopts = """

View File

@@ -1,3 +1,2 @@
esm
esm >=3.2.1post1,<4.0.0
pytest
httpx # TODO(williamxi): Remove this after the esm repo is fixed

View File

@@ -1,7 +1,7 @@
import os
import pytest
import torch
from esm.sdk import client # pyright: ignore
from esm.sdk.api import ( # pyright: ignore
ESMProtein,
@@ -37,6 +37,7 @@ def test_oss_esm3_client():
logits_config = LogitsConfig(sequence=True, return_embeddings=True)
result = esm3_client.logits(input=encoded_protein, config=logits_config)
assert isinstance(result, LogitsOutput)
assert isinstance(result.logits.sequence, torch.Tensor)
sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1))
result = esm3_client.forward_and_sample(
@@ -53,7 +54,7 @@ def test_oss_esm3_client():
def test_oss_esmc_client():
assert URL is not None
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
sequence = "MALWMRLLPLLALLALAVPDPAAA"
model = "esmc-300m-2024-12"
esmc_client = client(model=model, url=URL, token=API_TOKEN)
@@ -69,13 +70,14 @@ def test_oss_esmc_client():
)
result = esmc_client.logits(input=encoded_protein, config=logits_config)
assert isinstance(result, LogitsOutput)
assert isinstance(result.logits.sequence, torch.Tensor)
@pytest.mark.sdk
def test_oss_sequence_structure_forge_inference_client():
assert URL is not None
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
sequence = "MALWMRLLPLLALLALAVPDPAAA"
model = "esm3-small-2024-03"
client = SequenceStructureForgeInferenceClient(
model=model, url=URL, token=API_TOKEN