This commit is contained in:
Neil Thomas
2025-09-19 21:46:14 +00:00
parent 0382151104
commit 34c6638b58
49 changed files with 339 additions and 102 deletions

View File

@@ -38,7 +38,6 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -2,7 +2,6 @@ import random
import torch
import torch.nn.functional as F
from esm.pretrained import (
ESM3_function_decoder_v0,
ESM3_sm_open_v0,
@@ -13,7 +12,9 @@ from esm.tokenization import get_esm3_model_tokenizers
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

View File

@@ -2,7 +2,6 @@ import os
from typing import cast
import numpy as np
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,

View File

@@ -72,7 +72,6 @@
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.utils.types import FunctionAnnotation\n",
@@ -497,9 +496,8 @@
"# Functions for visualizing InterPro function annotations\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from matplotlib import colormaps\n",
"\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"from matplotlib import colormaps\n",
"\n",
"\n",
"def visualize_function_annotations(\n",

View File

@@ -64,7 +64,6 @@
"import matplotlib.pyplot as pl\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -36,7 +36,6 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -49,7 +49,6 @@
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]

View File

@@ -1 +1,2 @@
__version__ = "3.2.2.post2"

View File

@@ -5,7 +5,10 @@ import torch
import torch.nn.functional as F
from torch import nn
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
from esm.layers.rotary import (
RotaryEmbedding,
TritonRotaryEmbedding,
)
try:
from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore

View File

@@ -2,8 +2,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
from esm.layers.attention import (
FlashMultiHeadAttention,
MultiHeadAttention,
)
from esm.layers.geom_attention import (
GeometricReasoningOriginalImpl,
)
from esm.utils.structure.affine3d import Affine3D

View File

@@ -2,7 +2,10 @@ import torch
import torch.nn as nn
from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
)
class Dim6RotStructureHead(nn.Module):

View File

@@ -13,7 +13,10 @@ from attr import dataclass
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
@@ -29,7 +32,10 @@ from esm.sdk.api import (
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
@@ -44,7 +50,9 @@ from esm.utils.sampling import (
get_default_sampling_config,
validate_sampling_config,
)
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
)
@dataclass

View File

@@ -12,7 +12,9 @@ from cloudpathlib import AnyPath
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_annotations, merge_ranges
from esm.utils.types import FunctionAnnotation

View File

@@ -7,7 +7,10 @@ from esm.layers.structure_proj import Dim6RotStructureHead
from esm.layers.transformer_stack import TransformerStack
from esm.utils.constants import esm3 as C
from esm.utils.misc import knn_graph
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
from esm.utils.structure.affine3d import (
Affine3D,
build_affine3d_from_coordinates,
)
from esm.utils.structure.predicted_aligned_error import (
compute_predicted_aligned_error,
compute_tm,

View File

@@ -6,8 +6,14 @@ import torch.nn as nn
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,

View File

@@ -2,19 +2,27 @@ from __future__ import annotations
from abc import ABC
from copy import deepcopy
from typing import Sequence
from typing import List, Sequence
import attr
import torch
from attr import asdict, define
import esm.utils.constants.api as C
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
from esm.tokenization import (
TokenizerCollectionProtocol,
get_esm3_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
from esm.utils.misc import (
get_chainbreak_boundaries_from_sequence,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
from esm.utils.structure.protein_complex import (
SINGLE_LETTER_CHAIN_IDS,
ProteinComplex,
)
from esm.utils.types import FunctionAnnotation, PathOrBuffer
@@ -35,6 +43,7 @@ class ESMProtein(ProteinType):
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None
# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
@@ -327,6 +336,8 @@ class InverseFoldingConfig:
temperature: float = 1.0
## Low Level Endpoint Types
@define
class SamplingTrackConfig:
@@ -391,6 +402,9 @@ class LogitsConfig:
ith_hidden_layer: int = -1
@define
class LogitsOutput:
logits: ForwardTrackData | None = None

View File

@@ -1,9 +1,13 @@
import asyncio
import time
from abc import ABC, abstractmethod
from typing import Any
from urllib.parse import urljoin
import httpx
from esm.sdk.api import ESMProteinError
from esm.sdk.retry import retry_decorator
from esm.utils.decoding import assemble_message
@@ -112,7 +116,10 @@ class _BaseForgeInferenceClient:
):
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
request,
potential_sequence_of_concern,
return_bytes,
headers,
)
response = await self.async_client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
@@ -142,7 +149,10 @@ class _BaseForgeInferenceClient:
):
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
request,
potential_sequence_of_concern,
return_bytes,
headers,
)
response = self.client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
@@ -160,3 +170,5 @@ class _BaseForgeInferenceClient:
error_code=500,
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
)

View File

@@ -4,7 +4,7 @@ import asyncio
import base64
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Sequence
from typing import Any, Literal, Sequence, cast
import torch
@@ -20,14 +20,21 @@ from esm.sdk.api import (
InverseFoldingConfig,
LogitsConfig,
LogitsOutput,
ProteinChain,
ProteinType,
SamplingConfig,
SamplingTrackConfig,
)
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
from esm.sdk.base_forge_client import (
_BaseForgeInferenceClient,
)
from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
from esm.utils.misc import (
deserialize_tensors,
maybe_list,
maybe_tensor,
)
from esm.utils.msa import MSA
from esm.utils.structure.input_builder import (
StructurePredictionInput,
@@ -101,9 +108,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
)
@staticmethod
def _process_fold_request(sequence: str, model_name: str | None):
def _process_fold_request(
sequence: str,
model_name: str | None,
):
request: dict[str, Any] = {"sequence": sequence}
request["model"] = model_name
return request
@@ -138,6 +149,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
return request
async def _async_fetch_msa(self, sequence: str) -> MSA:
print("Fetching MSA ... this may take a few minutes")
# Accept both "|" and ":" as the chainbreak token.
@@ -176,11 +188,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
del potential_sequence_of_concern
request = self._process_fold_request(
sequence, model_name if model_name is not None else self.model
sequence,
model_name if model_name is not None else self.model,
)
try:
data = await self._async_post("fold", request)
data = await self._async_post(
"fold",
request,
)
except ESMProteinError as e:
return e
@@ -207,11 +223,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
del potential_sequence_of_concern
request = self._process_fold_request(
sequence, model_name if model_name is not None else self.model
sequence,
model_name if model_name is not None else self.model,
)
try:
data = self._post("fold", request)
data = self._post(
"fold",
request,
)
except ESMProteinError as e:
return e
@@ -219,7 +239,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
@retry_decorator
async def async_fold_all_atom(
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
self,
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
@@ -228,11 +250,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
model_name: Override the client level model name if needed
"""
request = self._process_fold_all_atom_request(
all_atom_input, model_name if model_name is not None else self.model
all_atom_input,
model_name if model_name is not None else self.model,
)
try:
data = await self._async_post("fold_all_atom", request)
data = await self._async_post(
"fold_all_atom",
request,
)
except ESMProteinError as e:
return e
@@ -240,7 +266,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
@retry_decorator
def fold_all_atom(
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
self,
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
@@ -249,11 +277,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
model_name: Override the client level model name if needed
"""
request = self._process_fold_all_atom_request(
all_atom_input, model_name if model_name is not None else self.model
all_atom_input,
model_name if model_name is not None else self.model,
)
try:
data = self._post("fold_all_atom", request)
data = self._post(
"fold_all_atom",
request,
)
except ESMProteinError as e:
return e
@@ -261,13 +293,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
@staticmethod
def _process_fold_all_atom_request(
all_atom_input: StructurePredictionInput, model_name: str | None = None
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
) -> dict[str, Any]:
request: dict[str, Any] = {
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
"model": model_name,
}
return request
@staticmethod
@@ -352,6 +386,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
return ESMProtein(sequence=data["sequence"])
class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
def __init__(
self,
@@ -1177,3 +1212,5 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
raise NotImplementedError(
f"Can not get underlying remote model {self.model} from a Forge client."
)

View File

@@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import Protocol
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer

View File

@@ -10,12 +10,24 @@ from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder
from esm.sdk.api import ESMProtein, ESMProteinTensor
from esm.tokenization import TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import api as api_constants
from esm.utils.constants import esm3 as C

View File

@@ -7,13 +7,26 @@ from esm.models.vqvae import StructureTokenEncoder
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import encode_function_annotations
from esm.utils.function.encode_decode import (
encode_function_annotations,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation
@@ -152,6 +165,8 @@ def tokenize_function_annotations(
return function_tokens, residue_annotation_tokens
# Tokenized Defaults
def get_default_sequence_tokens(
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
@@ -227,3 +242,5 @@ def get_default_residue_annotation_tokens(
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
return residue_annotation_tokens

View File

@@ -7,7 +7,10 @@ from typing import Any, Callable, List
from tqdm import tqdm
from esm.sdk.api import ESMProteinError
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
from esm.sdk.retry import (
retry_if_specific_error,
skip_retries_var,
)
TQDM_BAR_FORMAT = (
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "

View File

@@ -3,9 +3,16 @@ from typing import Sequence
import torch
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.models.function_decoder import (
FunctionTokenDecoder,
merge_annotations,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.types import FunctionAnnotation

View File

@@ -19,8 +19,13 @@ from esm.sdk.api import (
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization import (
EsmTokenizerBase,
TokenizerCollectionProtocol,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants import esm3 as C
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import os
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import is_dataclass
from io import BytesIO
from typing import (
@@ -261,7 +262,7 @@ def unbinpack(
return stack_variable_length_tensors(unpacked_tensors, pad_value)
def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore
def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore
"""
Returns an autocast context manager that disables downcasting by AMP.
@@ -273,6 +274,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast
"""
if device_type == "cpu":
return torch.amp.autocast(device_type, enabled=False) # type: ignore
elif device_type == "mps":
# For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast.
return nullcontext()
elif device_type == "cuda":
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
else:

View File

@@ -1,3 +1,7 @@
from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence
from esm.utils.msa.msa import (
MSA,
FastMSA,
remove_insertions_from_sequence,
)
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]

View File

@@ -12,8 +12,15 @@ from Bio import SeqIO
from scipy.spatial.distance import cdist
from esm.utils.misc import slice_any_object
from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter
from esm.utils.parsing import FastaEntry, read_sequences, write_sequences
from esm.utils.msa.filter_sequences import (
greedy_select_indices,
hhfilter,
)
from esm.utils.parsing import (
FastaEntry,
read_sequences,
write_sequences,
)
from esm.utils.sequential_dataclass import SequentialDataclass
from esm.utils.system import PathOrBuffer

View File

@@ -5,9 +5,18 @@ import attr
import torch
import torch.nn.functional as F
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.sdk.api import (
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import (
TokenizerCollectionProtocol,
get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.utils.constants.esm3 import (
MAX_RESIDUE_ANNOTATIONS,
SASA_DISCRETIZATION_BOUNDARIES,

View File

@@ -6,7 +6,9 @@ from typing import Any, ClassVar, Protocol, TypeVar
import numpy as np
import torch
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
from esm.utils.structure.protein_structure import (
compute_affine_and_rmsd,
)
class Alignable(Protocol):

View File

@@ -1,6 +1,8 @@
import numpy as np
from esm.utils.structure.protein_structure import index_by_atom_name
from esm.utils.structure.protein_structure import (
index_by_atom_name,
)
class AtomIndexer:

View File

@@ -4,6 +4,7 @@ from typing import Any, Sequence
import numpy as np
@dataclass
class Modification:
position: int # zero-indexed

View File

@@ -16,8 +16,14 @@ import numpy as np
import torch
from esm.utils import residue_constants
from esm.utils.structure.metrics import compute_lddt, compute_rmsd
from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata
from esm.utils.structure.metrics import (
compute_lddt,
compute_rmsd,
)
from esm.utils.structure.protein_complex import (
ProteinComplex,
ProteinComplexMetadata,
)
@dataclass

View File

@@ -25,13 +25,21 @@ from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.atom_indexer import AtomIndexer
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
from esm.utils.structure.metrics import (
compute_gdt_ts,
compute_lddt_ca,
)
from esm.utils.structure.mmcif_parsing import (
MmcifWrapper,
Residue,
)
from esm.utils.structure.normalize_coordinates import (
apply_frame_to_coords,
get_protein_normalization_frame,
)
from esm.utils.structure.protein_structure import index_by_atom_name
from esm.utils.structure.protein_structure import (
index_by_atom_name,
)
from esm.utils.types import PathOrBuffer
msgpack_numpy.patch()
@@ -393,6 +401,7 @@ class ProteinChain:
bytes = input
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
def sasa(self, by_residue: bool = True):
arr = self.atom_array_no_insertions
sasa_per_atom = bs.sasa(arr) # type: ignore
@@ -698,6 +707,7 @@ class ProteinChain:
)
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
@classmethod
def chain_iterable_from_mmcif(
cls,

View File

@@ -32,8 +32,14 @@ from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.atom_indexer import AtomIndexer
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
from esm.utils.structure.metrics import (
compute_gdt_ts,
compute_lddt_ca,
)
from esm.utils.structure.mmcif_parsing import (
MmcifWrapper,
NoProteinError,
)
from esm.utils.structure.protein_chain import (
ProteinChain,
chain_to_ndarray,

View File

@@ -4,7 +4,9 @@ import pygtrie
from ipywidgets import widgets
from esm.sdk.api import FunctionAnnotation
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
TRIE: pygtrie.CharTrie | None = None

View File

@@ -7,11 +7,15 @@ import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from esm.sdk.api import ESMProtein
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.drawing.draw_function_annotations import (
draw_function_annotations,
)
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
from esm.widgets.utils.drawing.draw_protein_structure import (
draw_protein_structure,
)
from esm.widgets.utils.serialization import (
create_download_button_from_buffer,
protein_to_pdb_buffer,

View File

@@ -3,9 +3,16 @@ from typing import Any, Callable, Sequence
import ipywidgets as widgets
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_hex,
)
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.prompting import PromptManager

View File

@@ -4,9 +4,16 @@ import ipywidgets as widgets
import pydssp
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_hex,
)
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.prompting import PromptManager

View File

@@ -6,7 +6,9 @@ from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_rgba_html_string,
)
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.prompting import PromptManager

View File

@@ -10,8 +10,12 @@ from matplotlib.patches import Rectangle
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils import indexing
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.drawing.draw_protein_structure import (
draw_protein_structure,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.prompting import PromptManager

View File

@@ -9,7 +9,10 @@ from matplotlib import colormaps
from PIL import Image
from esm.sdk.api import FunctionAnnotation
from esm.utils.function.interpro import InterPro, InterProEntryType
from esm.utils.function.interpro import (
InterPro,
InterProEntryType,
)
@contextmanager

View File

@@ -9,7 +9,9 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
from esm.utils import encoding
from esm.widgets.utils import indexing
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.printing import wrapped_print

View File

@@ -13,9 +13,13 @@ from esm.sdk.api import (
GenerationConfig,
)
from esm.utils.constants import models
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.serialization import create_download_results_button
from esm.widgets.utils.serialization import (
create_download_results_button,
)
def create_esm3_generation_launcher(

View File

@@ -1,6 +1,8 @@
from ipywidgets import widgets
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
from esm.widgets.components.sasa_prompt_selector import (
create_sasa_prompt_selector,
)
from esm.widgets.components.secondary_structure_prompt_selector import (
create_secondary_structure_prompt_selector,
)

View File

@@ -4,12 +4,20 @@ from ipywidgets import widgets
from esm.sdk.api import ESM3InferenceClient, ESMProtein
from esm.utils.constants import esm3 as C
from esm.widgets.components.function_annotator import create_function_annotator
from esm.widgets.components.function_annotator import (
create_function_annotator,
)
from esm.widgets.utils.prompting import PromptManagerCollection
from esm.widgets.utils.protein_import import ProteinImporter
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
from esm.widgets.views.esm3_generation_launcher import (
create_esm3_generation_launcher,
)
from esm.widgets.views.esm3_prompt_preview import (
create_esm3_prompt_preview,
)
from esm.widgets.views.esm3_prompt_selector import (
create_esm3_prompt_selector,
)
def create_generation_ui(

View File

@@ -6,7 +6,9 @@ from esm.sdk.api import (
ESMProteinError,
GenerationConfig,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import ProteinImporter

View File

@@ -4,7 +4,10 @@ from textwrap import dedent
from ipywidgets import widgets
from esm.widgets.utils.clients import get_forge_client, get_local_client
from esm.widgets.utils.clients import (
get_forge_client,
get_local_client,
)
from esm.widgets.utils.types import ClientInitContainer

View File

@@ -6,7 +6,9 @@ from esm.sdk.api import (
ESMProteinError,
GenerationConfig,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import ProteinImporter

View File

@@ -2,7 +2,6 @@ import os
import pytest
import torch
from esm.sdk import client # pyright: ignore
from esm.sdk.api import ( # pyright: ignore
ESMProtein,