mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Sync
This commit is contained in:
@@ -38,7 +38,6 @@
|
||||
"\n",
|
||||
"!pip install py3Dmol\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.models.esm3 import ESM3\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -2,7 +2,6 @@ import random
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.pretrained import (
|
||||
ESM3_function_decoder_v0,
|
||||
ESM3_sm_open_v0,
|
||||
@@ -13,7 +12,9 @@ from esm.tokenization import get_esm3_model_tokenizers
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
|
||||
@@ -72,7 +72,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from biotite.database import rcsb\n",
|
||||
"\n",
|
||||
"from esm.sdk.api import ESMProtein\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain\n",
|
||||
"from esm.utils.types import FunctionAnnotation\n",
|
||||
@@ -497,9 +496,8 @@
|
||||
"# Functions for visualizing InterPro function annotations\n",
|
||||
"\n",
|
||||
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def visualize_function_annotations(\n",
|
||||
|
||||
@@ -64,7 +64,6 @@
|
||||
"import matplotlib.pyplot as pl\n",
|
||||
"import py3Dmol\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -36,7 +36,6 @@
|
||||
"\n",
|
||||
"!pip install py3Dmol\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -49,7 +49,6 @@
|
||||
"source": [
|
||||
"import biotite.structure as bs\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
|
||||
]
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
__version__ = "3.2.2.post2"
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
|
||||
from esm.layers.rotary import (
|
||||
RotaryEmbedding,
|
||||
TritonRotaryEmbedding,
|
||||
)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore
|
||||
|
||||
@@ -2,8 +2,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
|
||||
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
|
||||
from esm.layers.attention import (
|
||||
FlashMultiHeadAttention,
|
||||
MultiHeadAttention,
|
||||
)
|
||||
from esm.layers.geom_attention import (
|
||||
GeometricReasoningOriginalImpl,
|
||||
)
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from esm.utils.constants.physics import BB_COORDINATES
|
||||
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
RotationMatrix,
|
||||
)
|
||||
|
||||
|
||||
class Dim6RotStructureHead(nn.Module):
|
||||
|
||||
@@ -13,7 +13,10 @@ from attr import dataclass
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||
from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
@@ -29,7 +32,10 @@ from esm.sdk.api import (
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
from esm.utils.decoding import decode_protein_tensor
|
||||
from esm.utils.generation import (
|
||||
_batch_forward,
|
||||
@@ -44,7 +50,9 @@ from esm.utils.sampling import (
|
||||
get_default_sampling_config,
|
||||
validate_sampling_config,
|
||||
)
|
||||
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
|
||||
from esm.utils.structure.affine3d import (
|
||||
build_affine3d_from_coordinates,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,9 @@ from cloudpathlib import AnyPath
|
||||
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import merge_annotations, merge_ranges
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -7,7 +7,10 @@ from esm.layers.structure_proj import Dim6RotStructureHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import knn_graph
|
||||
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
build_affine3d_from_coordinates,
|
||||
)
|
||||
from esm.utils.structure.predicted_aligned_error import (
|
||||
compute_predicted_aligned_error,
|
||||
compute_tm,
|
||||
|
||||
@@ -6,8 +6,14 @@ import torch.nn as nn
|
||||
from esm.models.esm3 import ESM3
|
||||
from esm.models.esmc import ESMC
|
||||
from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
|
||||
from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
get_esm3_model_tokenizers,
|
||||
get_esmc_model_tokenizers,
|
||||
)
|
||||
from esm.utils.constants.esm3 import data_root
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_FUNCTION_DECODER_V0,
|
||||
|
||||
@@ -2,19 +2,27 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import attr
|
||||
import torch
|
||||
from attr import asdict, define
|
||||
|
||||
import esm.utils.constants.api as C
|
||||
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_esm3_model_tokenizers,
|
||||
)
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
|
||||
from esm.utils.misc import (
|
||||
get_chainbreak_boundaries_from_sequence,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
|
||||
from esm.utils.structure.protein_complex import (
|
||||
SINGLE_LETTER_CHAIN_IDS,
|
||||
ProteinComplex,
|
||||
)
|
||||
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
||||
|
||||
|
||||
@@ -35,6 +43,7 @@ class ESMProtein(ProteinType):
|
||||
plddt: torch.Tensor | None = None
|
||||
ptm: torch.Tensor | None = None
|
||||
|
||||
|
||||
# When calling EvolutionaryScale API, use this flag to disclose any
|
||||
# sequences that may potentially have concerns.
|
||||
# Such sequences may not go through standard safety filter for approved users.
|
||||
@@ -327,6 +336,8 @@ class InverseFoldingConfig:
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
|
||||
|
||||
## Low Level Endpoint Types
|
||||
@define
|
||||
class SamplingTrackConfig:
|
||||
@@ -391,6 +402,9 @@ class LogitsConfig:
|
||||
ith_hidden_layer: int = -1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@define
|
||||
class LogitsOutput:
|
||||
logits: ForwardTrackData | None = None
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.decoding import assemble_message
|
||||
|
||||
|
||||
@@ -112,7 +116,10 @@ class _BaseForgeInferenceClient:
|
||||
):
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
return_bytes,
|
||||
headers,
|
||||
)
|
||||
response = await self.async_client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
@@ -142,7 +149,10 @@ class _BaseForgeInferenceClient:
|
||||
):
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
return_bytes,
|
||||
headers,
|
||||
)
|
||||
response = self.client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
@@ -160,3 +170,5 @@ class _BaseForgeInferenceClient:
|
||||
error_code=500,
|
||||
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import base64
|
||||
import pickle
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Sequence
|
||||
from typing import Any, Literal, Sequence, cast
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,14 +20,21 @@ from esm.sdk.api import (
|
||||
InverseFoldingConfig,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
ProteinChain,
|
||||
ProteinType,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
|
||||
from esm.sdk.base_forge_client import (
|
||||
_BaseForgeInferenceClient,
|
||||
)
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
||||
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
|
||||
from esm.utils.misc import (
|
||||
deserialize_tensors,
|
||||
maybe_list,
|
||||
maybe_tensor,
|
||||
)
|
||||
from esm.utils.msa import MSA
|
||||
from esm.utils.structure.input_builder import (
|
||||
StructurePredictionInput,
|
||||
@@ -101,9 +108,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_request(sequence: str, model_name: str | None):
|
||||
def _process_fold_request(
|
||||
sequence: str,
|
||||
model_name: str | None,
|
||||
):
|
||||
request: dict[str, Any] = {"sequence": sequence}
|
||||
|
||||
|
||||
request["model"] = model_name
|
||||
|
||||
return request
|
||||
@@ -138,6 +149,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
|
||||
return request
|
||||
|
||||
|
||||
async def _async_fetch_msa(self, sequence: str) -> MSA:
|
||||
print("Fetching MSA ... this may take a few minutes")
|
||||
# Accept both "|" and ":" as the chainbreak token.
|
||||
@@ -176,11 +188,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
del potential_sequence_of_concern
|
||||
|
||||
request = self._process_fold_request(
|
||||
sequence, model_name if model_name is not None else self.model
|
||||
sequence,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = await self._async_post("fold", request)
|
||||
data = await self._async_post(
|
||||
"fold",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
@@ -207,11 +223,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
del potential_sequence_of_concern
|
||||
|
||||
request = self._process_fold_request(
|
||||
sequence, model_name if model_name is not None else self.model
|
||||
sequence,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = self._post("fold", request)
|
||||
data = self._post(
|
||||
"fold",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
@@ -219,7 +239,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
|
||||
@retry_decorator
|
||||
async def async_fold_all_atom(
|
||||
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||
self,
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
|
||||
|
||||
@@ -228,11 +250,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
model_name: Override the client level model name if needed
|
||||
"""
|
||||
request = self._process_fold_all_atom_request(
|
||||
all_atom_input, model_name if model_name is not None else self.model
|
||||
all_atom_input,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = await self._async_post("fold_all_atom", request)
|
||||
data = await self._async_post(
|
||||
"fold_all_atom",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
@@ -240,7 +266,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
|
||||
@retry_decorator
|
||||
def fold_all_atom(
|
||||
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||
self,
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
|
||||
|
||||
@@ -249,11 +277,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
model_name: Override the client level model name if needed
|
||||
"""
|
||||
request = self._process_fold_all_atom_request(
|
||||
all_atom_input, model_name if model_name is not None else self.model
|
||||
all_atom_input,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = self._post("fold_all_atom", request)
|
||||
data = self._post(
|
||||
"fold_all_atom",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
@@ -261,13 +293,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_all_atom_request(
|
||||
all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
request: dict[str, Any] = {
|
||||
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
@@ -352,6 +386,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
return ESMProtein(sequence=data["sequence"])
|
||||
|
||||
|
||||
|
||||
class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1177,3 +1212,5 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
raise NotImplementedError(
|
||||
f"Can not get underlying remote model {self.model} from a Forge client."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
|
||||
from .function_tokenizer import InterProQuantizedTokenizer
|
||||
from .residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
|
||||
@@ -10,12 +10,24 @@ from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder
|
||||
from esm.sdk.api import ESMProtein, ESMProteinTensor
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.tokenization.sasa_tokenizer import (
|
||||
SASADiscretizingTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.ss_tokenizer import (
|
||||
SecondaryStructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.structure_tokenizer import (
|
||||
StructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
||||
from esm.utils.constants import api as api_constants
|
||||
from esm.utils.constants import esm3 as C
|
||||
|
||||
@@ -7,13 +7,26 @@ from esm.models.vqvae import StructureTokenEncoder
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.tokenization.sasa_tokenizer import (
|
||||
SASADiscretizingTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.ss_tokenizer import (
|
||||
SecondaryStructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.structure_tokenizer import (
|
||||
StructureTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.function.encode_decode import encode_function_annotations
|
||||
from esm.utils.function.encode_decode import (
|
||||
encode_function_annotations,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -152,6 +165,8 @@ def tokenize_function_annotations(
|
||||
return function_tokens, residue_annotation_tokens
|
||||
|
||||
|
||||
|
||||
|
||||
# Tokenized Defaults
|
||||
def get_default_sequence_tokens(
|
||||
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
||||
@@ -227,3 +242,5 @@ def get_default_residue_annotation_tokens(
|
||||
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
|
||||
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
|
||||
return residue_annotation_tokens
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,10 @@ from typing import Any, Callable, List
|
||||
from tqdm import tqdm
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
|
||||
from esm.sdk.retry import (
|
||||
retry_if_specific_error,
|
||||
skip_retries_var,
|
||||
)
|
||||
|
||||
TQDM_BAR_FORMAT = (
|
||||
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
|
||||
|
||||
@@ -3,9 +3,16 @@ from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.models.function_decoder import (
|
||||
FunctionTokenDecoder,
|
||||
merge_annotations,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -19,8 +19,13 @@ from esm.sdk.api import (
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization import (
|
||||
EsmTokenizerBase,
|
||||
TokenizerCollectionProtocol,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import stack_variable_length_tensors
|
||||
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import is_dataclass
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
@@ -261,7 +262,7 @@ def unbinpack(
|
||||
return stack_variable_length_tensors(unpacked_tensors, pad_value)
|
||||
|
||||
|
||||
def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore
|
||||
def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore
|
||||
"""
|
||||
Returns an autocast context manager that disables downcasting by AMP.
|
||||
|
||||
@@ -273,6 +274,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast
|
||||
"""
|
||||
if device_type == "cpu":
|
||||
return torch.amp.autocast(device_type, enabled=False) # type: ignore
|
||||
elif device_type == "mps":
|
||||
# For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast.
|
||||
return nullcontext()
|
||||
elif device_type == "cuda":
|
||||
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence
|
||||
from esm.utils.msa.msa import (
|
||||
MSA,
|
||||
FastMSA,
|
||||
remove_insertions_from_sequence,
|
||||
)
|
||||
|
||||
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]
|
||||
|
||||
@@ -12,8 +12,15 @@ from Bio import SeqIO
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from esm.utils.misc import slice_any_object
|
||||
from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter
|
||||
from esm.utils.parsing import FastaEntry, read_sequences, write_sequences
|
||||
from esm.utils.msa.filter_sequences import (
|
||||
greedy_select_indices,
|
||||
hhfilter,
|
||||
)
|
||||
from esm.utils.parsing import (
|
||||
FastaEntry,
|
||||
read_sequences,
|
||||
write_sequences,
|
||||
)
|
||||
from esm.utils.sequential_dataclass import SequentialDataclass
|
||||
from esm.utils.system import PathOrBuffer
|
||||
|
||||
|
||||
@@ -5,9 +5,18 @@ import attr
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
|
||||
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.sdk.api import (
|
||||
ESMProteinTensor,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_invalid_tokenizer_ids,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants.esm3 import (
|
||||
MAX_RESIDUE_ANNOTATIONS,
|
||||
SASA_DISCRETIZATION_BOUNDARIES,
|
||||
|
||||
@@ -6,7 +6,9 @@ from typing import Any, ClassVar, Protocol, TypeVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
|
||||
from esm.utils.structure.protein_structure import (
|
||||
compute_affine_and_rmsd,
|
||||
)
|
||||
|
||||
|
||||
class Alignable(Protocol):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||
from esm.utils.structure.protein_structure import (
|
||||
index_by_atom_name,
|
||||
)
|
||||
|
||||
|
||||
class AtomIndexer:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Sequence
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class Modification:
|
||||
position: int # zero-indexed
|
||||
|
||||
@@ -16,8 +16,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from esm.utils import residue_constants
|
||||
from esm.utils.structure.metrics import compute_lddt, compute_rmsd
|
||||
from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_lddt,
|
||||
compute_rmsd,
|
||||
)
|
||||
from esm.utils.structure.protein_complex import (
|
||||
ProteinComplex,
|
||||
ProteinComplexMetadata,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -25,13 +25,21 @@ from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_gdt_ts,
|
||||
compute_lddt_ca,
|
||||
)
|
||||
from esm.utils.structure.mmcif_parsing import (
|
||||
MmcifWrapper,
|
||||
Residue,
|
||||
)
|
||||
from esm.utils.structure.normalize_coordinates import (
|
||||
apply_frame_to_coords,
|
||||
get_protein_normalization_frame,
|
||||
)
|
||||
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||
from esm.utils.structure.protein_structure import (
|
||||
index_by_atom_name,
|
||||
)
|
||||
from esm.utils.types import PathOrBuffer
|
||||
|
||||
msgpack_numpy.patch()
|
||||
@@ -393,6 +401,7 @@ class ProteinChain:
|
||||
bytes = input
|
||||
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
|
||||
|
||||
|
||||
def sasa(self, by_residue: bool = True):
|
||||
arr = self.atom_array_no_insertions
|
||||
sasa_per_atom = bs.sasa(arr) # type: ignore
|
||||
@@ -698,6 +707,7 @@ class ProteinChain:
|
||||
)
|
||||
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
||||
|
||||
|
||||
@classmethod
|
||||
def chain_iterable_from_mmcif(
|
||||
cls,
|
||||
|
||||
@@ -32,8 +32,14 @@ from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_gdt_ts,
|
||||
compute_lddt_ca,
|
||||
)
|
||||
from esm.utils.structure.mmcif_parsing import (
|
||||
MmcifWrapper,
|
||||
NoProteinError,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import (
|
||||
ProteinChain,
|
||||
chain_to_ndarray,
|
||||
|
||||
@@ -4,7 +4,9 @@ import pygtrie
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import FunctionAnnotation
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
|
||||
TRIE: pygtrie.CharTrie | None = None
|
||||
|
||||
|
||||
@@ -7,11 +7,15 @@ import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from esm.sdk.api import ESMProtein
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_function_annotations import (
|
||||
draw_function_annotations,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
||||
draw_protein_structure,
|
||||
)
|
||||
from esm.widgets.utils.serialization import (
|
||||
create_download_button_from_buffer,
|
||||
protein_to_pdb_buffer,
|
||||
|
||||
@@ -3,9 +3,16 @@ from typing import Any, Callable, Sequence
|
||||
import ipywidgets as widgets
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_hex,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,16 @@ import ipywidgets as widgets
|
||||
import pydssp
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_hex,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_rgba_html_string,
|
||||
)
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -10,8 +10,12 @@ from matplotlib.patches import Rectangle
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils import indexing
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
||||
draw_protein_structure,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
@@ -9,7 +9,10 @@ from matplotlib import colormaps
|
||||
from PIL import Image
|
||||
|
||||
from esm.sdk.api import FunctionAnnotation
|
||||
from esm.utils.function.interpro import InterPro, InterProEntryType
|
||||
from esm.utils.function.interpro import (
|
||||
InterPro,
|
||||
InterProEntryType,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -9,7 +9,9 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
|
||||
from esm.utils import encoding
|
||||
from esm.widgets.utils import indexing
|
||||
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,13 @@ from esm.sdk.api import (
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.utils.constants import models
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.serialization import create_download_results_button
|
||||
from esm.widgets.utils.serialization import (
|
||||
create_download_results_button,
|
||||
)
|
||||
|
||||
|
||||
def create_esm3_generation_launcher(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
|
||||
from esm.widgets.components.sasa_prompt_selector import (
|
||||
create_sasa_prompt_selector,
|
||||
)
|
||||
from esm.widgets.components.secondary_structure_prompt_selector import (
|
||||
create_secondary_structure_prompt_selector,
|
||||
)
|
||||
|
||||
@@ -4,12 +4,20 @@ from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import ESM3InferenceClient, ESMProtein
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.widgets.components.function_annotator import create_function_annotator
|
||||
from esm.widgets.components.function_annotator import (
|
||||
create_function_annotator,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManagerCollection
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
|
||||
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
|
||||
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
|
||||
from esm.widgets.views.esm3_generation_launcher import (
|
||||
create_esm3_generation_launcher,
|
||||
)
|
||||
from esm.widgets.views.esm3_prompt_preview import (
|
||||
create_esm3_prompt_preview,
|
||||
)
|
||||
from esm.widgets.views.esm3_prompt_selector import (
|
||||
create_esm3_prompt_selector,
|
||||
)
|
||||
|
||||
|
||||
def create_generation_ui(
|
||||
|
||||
@@ -6,7 +6,9 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
@@ -4,7 +4,10 @@ from textwrap import dedent
|
||||
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.widgets.utils.clients import get_forge_client, get_local_client
|
||||
from esm.widgets.utils.clients import (
|
||||
get_forge_client,
|
||||
get_local_client,
|
||||
)
|
||||
from esm.widgets.utils.types import ClientInitContainer
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from esm.sdk import client # pyright: ignore
|
||||
from esm.sdk.api import ( # pyright: ignore
|
||||
ESMProtein,
|
||||
|
||||
Reference in New Issue
Block a user