mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 09:04:23 +08:00
ruff checks
This commit is contained in:
@@ -38,6 +38,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"!pip install py3Dmol\n",
|
"!pip install py3Dmol\n",
|
||||||
"import py3Dmol\n",
|
"import py3Dmol\n",
|
||||||
|
"\n",
|
||||||
"from esm.models.esm3 import ESM3\n",
|
"from esm.models.esm3 import ESM3\n",
|
||||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import random
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from esm.pretrained import (
|
from esm.pretrained import (
|
||||||
ESM3_function_decoder_v0,
|
ESM3_function_decoder_v0,
|
||||||
ESM3_sm_open_v0,
|
ESM3_sm_open_v0,
|
||||||
@@ -12,9 +13,7 @@ from esm.tokenization import get_esm3_model_tokenizers
|
|||||||
from esm.tokenization.function_tokenizer import (
|
from esm.tokenization.function_tokenizer import (
|
||||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||||
)
|
)
|
||||||
from esm.tokenization.sequence_tokenizer import (
|
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||||
EsmSequenceTokenizer,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.protein_chain import ProteinChain
|
from esm.utils.structure.protein_chain import ProteinChain
|
||||||
from esm.utils.types import FunctionAnnotation
|
from esm.utils.types import FunctionAnnotation
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from esm.sdk.api import (
|
from esm.sdk.api import (
|
||||||
ESM3InferenceClient,
|
ESM3InferenceClient,
|
||||||
ESMProtein,
|
ESMProtein,
|
||||||
|
|||||||
@@ -72,6 +72,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from biotite.database import rcsb\n",
|
"from biotite.database import rcsb\n",
|
||||||
|
"\n",
|
||||||
"from esm.sdk.api import ESMProtein\n",
|
"from esm.sdk.api import ESMProtein\n",
|
||||||
"from esm.utils.structure.protein_chain import ProteinChain\n",
|
"from esm.utils.structure.protein_chain import ProteinChain\n",
|
||||||
"from esm.utils.types import FunctionAnnotation\n",
|
"from esm.utils.types import FunctionAnnotation\n",
|
||||||
@@ -496,9 +497,10 @@
|
|||||||
"# Functions for visualizing InterPro function annotations\n",
|
"# Functions for visualizing InterPro function annotations\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
|
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
|
||||||
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
|
|
||||||
"from matplotlib import colormaps\n",
|
"from matplotlib import colormaps\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
|
||||||
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def visualize_function_annotations(\n",
|
"def visualize_function_annotations(\n",
|
||||||
" annotations: list[FunctionAnnotation],\n",
|
" annotations: list[FunctionAnnotation],\n",
|
||||||
|
|||||||
@@ -64,6 +64,7 @@
|
|||||||
"import matplotlib.pyplot as pl\n",
|
"import matplotlib.pyplot as pl\n",
|
||||||
"import py3Dmol\n",
|
"import py3Dmol\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
|
"\n",
|
||||||
"from esm.sdk import client\n",
|
"from esm.sdk import client\n",
|
||||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||||
|
|||||||
@@ -36,6 +36,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"!pip install py3Dmol\n",
|
"!pip install py3Dmol\n",
|
||||||
"import py3Dmol\n",
|
"import py3Dmol\n",
|
||||||
|
"\n",
|
||||||
"from esm.sdk import client\n",
|
"from esm.sdk import client\n",
|
||||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||||
|
|||||||
@@ -49,6 +49,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import biotite.structure as bs\n",
|
"import biotite.structure as bs\n",
|
||||||
"import py3Dmol\n",
|
"import py3Dmol\n",
|
||||||
|
"\n",
|
||||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||||
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
|
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,2 +1 @@
|
|||||||
__version__ = "3.2.2.post2"
|
__version__ = "3.2.2.post2"
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from esm.layers.rotary import (
|
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
|
||||||
RotaryEmbedding,
|
|
||||||
TritonRotaryEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore
|
from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore
|
||||||
|
|||||||
@@ -2,13 +2,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from esm.layers.attention import (
|
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
|
||||||
FlashMultiHeadAttention,
|
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
|
||||||
MultiHeadAttention,
|
|
||||||
)
|
|
||||||
from esm.layers.geom_attention import (
|
|
||||||
GeometricReasoningOriginalImpl,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.affine3d import Affine3D
|
from esm.utils.structure.affine3d import Affine3D
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from esm.utils.constants.physics import BB_COORDINATES
|
from esm.utils.constants.physics import BB_COORDINATES
|
||||||
from esm.utils.structure.affine3d import (
|
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
|
||||||
Affine3D,
|
|
||||||
RotationMatrix,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Dim6RotStructureHead(nn.Module):
|
class Dim6RotStructureHead(nn.Module):
|
||||||
|
|||||||
@@ -13,10 +13,7 @@ from attr import dataclass
|
|||||||
from esm.layers.regression_head import RegressionHead
|
from esm.layers.regression_head import RegressionHead
|
||||||
from esm.layers.transformer_stack import TransformerStack
|
from esm.layers.transformer_stack import TransformerStack
|
||||||
from esm.models.function_decoder import FunctionTokenDecoder
|
from esm.models.function_decoder import FunctionTokenDecoder
|
||||||
from esm.models.vqvae import (
|
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||||
StructureTokenDecoder,
|
|
||||||
StructureTokenEncoder,
|
|
||||||
)
|
|
||||||
from esm.sdk.api import (
|
from esm.sdk.api import (
|
||||||
ESM3InferenceClient,
|
ESM3InferenceClient,
|
||||||
ESMProtein,
|
ESMProtein,
|
||||||
@@ -32,10 +29,7 @@ from esm.sdk.api import (
|
|||||||
from esm.tokenization import TokenizerCollectionProtocol
|
from esm.tokenization import TokenizerCollectionProtocol
|
||||||
from esm.utils import encoding
|
from esm.utils import encoding
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.utils.constants.models import (
|
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||||
ESM3_OPEN_SMALL,
|
|
||||||
normalize_model_name,
|
|
||||||
)
|
|
||||||
from esm.utils.decoding import decode_protein_tensor
|
from esm.utils.decoding import decode_protein_tensor
|
||||||
from esm.utils.generation import (
|
from esm.utils.generation import (
|
||||||
_batch_forward,
|
_batch_forward,
|
||||||
@@ -50,9 +44,7 @@ from esm.utils.sampling import (
|
|||||||
get_default_sampling_config,
|
get_default_sampling_config,
|
||||||
validate_sampling_config,
|
validate_sampling_config,
|
||||||
)
|
)
|
||||||
from esm.utils.structure.affine3d import (
|
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
|
||||||
build_affine3d_from_coordinates,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ from cloudpathlib import AnyPath
|
|||||||
|
|
||||||
from esm.layers.regression_head import RegressionHead
|
from esm.layers.regression_head import RegressionHead
|
||||||
from esm.layers.transformer_stack import TransformerStack
|
from esm.layers.transformer_stack import TransformerStack
|
||||||
from esm.tokenization.function_tokenizer import (
|
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||||
InterProQuantizedTokenizer,
|
|
||||||
)
|
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.utils.misc import merge_annotations, merge_ranges
|
from esm.utils.misc import merge_annotations, merge_ranges
|
||||||
from esm.utils.types import FunctionAnnotation
|
from esm.utils.types import FunctionAnnotation
|
||||||
|
|||||||
@@ -7,10 +7,7 @@ from esm.layers.structure_proj import Dim6RotStructureHead
|
|||||||
from esm.layers.transformer_stack import TransformerStack
|
from esm.layers.transformer_stack import TransformerStack
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.utils.misc import knn_graph
|
from esm.utils.misc import knn_graph
|
||||||
from esm.utils.structure.affine3d import (
|
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
|
||||||
Affine3D,
|
|
||||||
build_affine3d_from_coordinates,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.predicted_aligned_error import (
|
from esm.utils.structure.predicted_aligned_error import (
|
||||||
compute_predicted_aligned_error,
|
compute_predicted_aligned_error,
|
||||||
compute_tm,
|
compute_tm,
|
||||||
|
|||||||
@@ -6,14 +6,8 @@ import torch.nn as nn
|
|||||||
from esm.models.esm3 import ESM3
|
from esm.models.esm3 import ESM3
|
||||||
from esm.models.esmc import ESMC
|
from esm.models.esmc import ESMC
|
||||||
from esm.models.function_decoder import FunctionTokenDecoder
|
from esm.models.function_decoder import FunctionTokenDecoder
|
||||||
from esm.models.vqvae import (
|
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||||
StructureTokenDecoder,
|
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
|
||||||
StructureTokenEncoder,
|
|
||||||
)
|
|
||||||
from esm.tokenization import (
|
|
||||||
get_esm3_model_tokenizers,
|
|
||||||
get_esmc_model_tokenizers,
|
|
||||||
)
|
|
||||||
from esm.utils.constants.esm3 import data_root
|
from esm.utils.constants.esm3 import data_root
|
||||||
from esm.utils.constants.models import (
|
from esm.utils.constants.models import (
|
||||||
ESM3_FUNCTION_DECODER_V0,
|
ESM3_FUNCTION_DECODER_V0,
|
||||||
|
|||||||
@@ -2,27 +2,19 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import List, Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import torch
|
import torch
|
||||||
from attr import asdict, define
|
from attr import asdict, define
|
||||||
|
|
||||||
import esm.utils.constants.api as C
|
import esm.utils.constants.api as C
|
||||||
from esm.tokenization import (
|
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
|
||||||
TokenizerCollectionProtocol,
|
|
||||||
get_esm3_model_tokenizers,
|
|
||||||
)
|
|
||||||
from esm.utils import encoding
|
from esm.utils import encoding
|
||||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||||
from esm.utils.misc import (
|
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
|
||||||
get_chainbreak_boundaries_from_sequence,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.protein_chain import ProteinChain
|
from esm.utils.structure.protein_chain import ProteinChain
|
||||||
from esm.utils.structure.protein_complex import (
|
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
|
||||||
SINGLE_LETTER_CHAIN_IDS,
|
|
||||||
ProteinComplex,
|
|
||||||
)
|
|
||||||
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +35,6 @@ class ESMProtein(ProteinType):
|
|||||||
plddt: torch.Tensor | None = None
|
plddt: torch.Tensor | None = None
|
||||||
ptm: torch.Tensor | None = None
|
ptm: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
# When calling EvolutionaryScale API, use this flag to disclose any
|
# When calling EvolutionaryScale API, use this flag to disclose any
|
||||||
# sequences that may potentially have concerns.
|
# sequences that may potentially have concerns.
|
||||||
# Such sequences may not go through standard safety filter for approved users.
|
# Such sequences may not go through standard safety filter for approved users.
|
||||||
@@ -336,8 +327,6 @@ class InverseFoldingConfig:
|
|||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Low Level Endpoint Types
|
## Low Level Endpoint Types
|
||||||
@define
|
@define
|
||||||
class SamplingTrackConfig:
|
class SamplingTrackConfig:
|
||||||
@@ -402,9 +391,6 @@ class LogitsConfig:
|
|||||||
ith_hidden_layer: int = -1
|
ith_hidden_layer: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@define
|
@define
|
||||||
class LogitsOutput:
|
class LogitsOutput:
|
||||||
logits: ForwardTrackData | None = None
|
logits: ForwardTrackData | None = None
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from esm.sdk.api import ESMProteinError
|
from esm.sdk.api import ESMProteinError
|
||||||
from esm.sdk.retry import retry_decorator
|
|
||||||
from esm.utils.decoding import assemble_message
|
from esm.utils.decoding import assemble_message
|
||||||
|
|
||||||
|
|
||||||
@@ -116,10 +112,7 @@ class _BaseForgeInferenceClient:
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
request, headers = self.prepare_request(
|
request, headers = self.prepare_request(
|
||||||
request,
|
request, potential_sequence_of_concern, return_bytes, headers
|
||||||
potential_sequence_of_concern,
|
|
||||||
return_bytes,
|
|
||||||
headers,
|
|
||||||
)
|
)
|
||||||
response = await self.async_client.post(
|
response = await self.async_client.post(
|
||||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||||
@@ -149,10 +142,7 @@ class _BaseForgeInferenceClient:
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
request, headers = self.prepare_request(
|
request, headers = self.prepare_request(
|
||||||
request,
|
request, potential_sequence_of_concern, return_bytes, headers
|
||||||
potential_sequence_of_concern,
|
|
||||||
return_bytes,
|
|
||||||
headers,
|
|
||||||
)
|
)
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||||
@@ -170,5 +160,3 @@ class _BaseForgeInferenceClient:
|
|||||||
error_code=500,
|
error_code=500,
|
||||||
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
|
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import asyncio
|
|||||||
import base64
|
import base64
|
||||||
import pickle
|
import pickle
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any, Literal, Sequence, cast
|
from typing import Any, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -20,21 +20,14 @@ from esm.sdk.api import (
|
|||||||
InverseFoldingConfig,
|
InverseFoldingConfig,
|
||||||
LogitsConfig,
|
LogitsConfig,
|
||||||
LogitsOutput,
|
LogitsOutput,
|
||||||
ProteinChain,
|
|
||||||
ProteinType,
|
ProteinType,
|
||||||
SamplingConfig,
|
SamplingConfig,
|
||||||
SamplingTrackConfig,
|
SamplingTrackConfig,
|
||||||
)
|
)
|
||||||
from esm.sdk.base_forge_client import (
|
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
|
||||||
_BaseForgeInferenceClient,
|
|
||||||
)
|
|
||||||
from esm.sdk.retry import retry_decorator
|
from esm.sdk.retry import retry_decorator
|
||||||
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
||||||
from esm.utils.misc import (
|
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
|
||||||
deserialize_tensors,
|
|
||||||
maybe_list,
|
|
||||||
maybe_tensor,
|
|
||||||
)
|
|
||||||
from esm.utils.msa import MSA
|
from esm.utils.msa import MSA
|
||||||
from esm.utils.structure.input_builder import (
|
from esm.utils.structure.input_builder import (
|
||||||
StructurePredictionInput,
|
StructurePredictionInput,
|
||||||
@@ -108,13 +101,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_fold_request(
|
def _process_fold_request(sequence: str, model_name: str | None):
|
||||||
sequence: str,
|
|
||||||
model_name: str | None,
|
|
||||||
):
|
|
||||||
request: dict[str, Any] = {"sequence": sequence}
|
request: dict[str, Any] = {"sequence": sequence}
|
||||||
|
|
||||||
|
|
||||||
request["model"] = model_name
|
request["model"] = model_name
|
||||||
|
|
||||||
return request
|
return request
|
||||||
@@ -149,7 +138,6 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
|
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
async def _async_fetch_msa(self, sequence: str) -> MSA:
|
async def _async_fetch_msa(self, sequence: str) -> MSA:
|
||||||
print("Fetching MSA ... this may take a few minutes")
|
print("Fetching MSA ... this may take a few minutes")
|
||||||
# Accept both "|" and ":" as the chainbreak token.
|
# Accept both "|" and ":" as the chainbreak token.
|
||||||
@@ -188,15 +176,11 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
del potential_sequence_of_concern
|
del potential_sequence_of_concern
|
||||||
|
|
||||||
request = self._process_fold_request(
|
request = self._process_fold_request(
|
||||||
sequence,
|
sequence, model_name if model_name is not None else self.model
|
||||||
model_name if model_name is not None else self.model,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = await self._async_post(
|
data = await self._async_post("fold", request)
|
||||||
"fold",
|
|
||||||
request,
|
|
||||||
)
|
|
||||||
except ESMProteinError as e:
|
except ESMProteinError as e:
|
||||||
return e
|
return e
|
||||||
|
|
||||||
@@ -223,15 +207,11 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
del potential_sequence_of_concern
|
del potential_sequence_of_concern
|
||||||
|
|
||||||
request = self._process_fold_request(
|
request = self._process_fold_request(
|
||||||
sequence,
|
sequence, model_name if model_name is not None else self.model
|
||||||
model_name if model_name is not None else self.model,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = self._post(
|
data = self._post("fold", request)
|
||||||
"fold",
|
|
||||||
request,
|
|
||||||
)
|
|
||||||
except ESMProteinError as e:
|
except ESMProteinError as e:
|
||||||
return e
|
return e
|
||||||
|
|
||||||
@@ -239,9 +219,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
async def async_fold_all_atom(
|
async def async_fold_all_atom(
|
||||||
self,
|
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||||
all_atom_input: StructurePredictionInput,
|
|
||||||
model_name: str | None = None,
|
|
||||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||||
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
|
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
|
||||||
|
|
||||||
@@ -250,15 +228,11 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
model_name: Override the client level model name if needed
|
model_name: Override the client level model name if needed
|
||||||
"""
|
"""
|
||||||
request = self._process_fold_all_atom_request(
|
request = self._process_fold_all_atom_request(
|
||||||
all_atom_input,
|
all_atom_input, model_name if model_name is not None else self.model
|
||||||
model_name if model_name is not None else self.model,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = await self._async_post(
|
data = await self._async_post("fold_all_atom", request)
|
||||||
"fold_all_atom",
|
|
||||||
request,
|
|
||||||
)
|
|
||||||
except ESMProteinError as e:
|
except ESMProteinError as e:
|
||||||
return e
|
return e
|
||||||
|
|
||||||
@@ -266,9 +240,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
|
|
||||||
@retry_decorator
|
@retry_decorator
|
||||||
def fold_all_atom(
|
def fold_all_atom(
|
||||||
self,
|
self, all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||||
all_atom_input: StructurePredictionInput,
|
|
||||||
model_name: str | None = None,
|
|
||||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||||
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
|
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
|
||||||
|
|
||||||
@@ -277,15 +249,11 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
model_name: Override the client level model name if needed
|
model_name: Override the client level model name if needed
|
||||||
"""
|
"""
|
||||||
request = self._process_fold_all_atom_request(
|
request = self._process_fold_all_atom_request(
|
||||||
all_atom_input,
|
all_atom_input, model_name if model_name is not None else self.model
|
||||||
model_name if model_name is not None else self.model,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = self._post(
|
data = self._post("fold_all_atom", request)
|
||||||
"fold_all_atom",
|
|
||||||
request,
|
|
||||||
)
|
|
||||||
except ESMProteinError as e:
|
except ESMProteinError as e:
|
||||||
return e
|
return e
|
||||||
|
|
||||||
@@ -293,15 +261,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_fold_all_atom_request(
|
def _process_fold_all_atom_request(
|
||||||
all_atom_input: StructurePredictionInput,
|
all_atom_input: StructurePredictionInput, model_name: str | None = None
|
||||||
model_name: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
request: dict[str, Any] = {
|
request: dict[str, Any] = {
|
||||||
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
|
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return request
|
return request
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -386,7 +352,6 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
|||||||
return ESMProtein(sequence=data["sequence"])
|
return ESMProtein(sequence=data["sequence"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1212,5 +1177,3 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Can not get underlying remote model {self.model} from a Forge client."
|
f"Can not get underlying remote model {self.model} from a Forge client."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
from esm.utils.constants.models import (
|
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||||
ESM3_OPEN_SMALL,
|
|
||||||
normalize_model_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .function_tokenizer import InterProQuantizedTokenizer
|
from .function_tokenizer import InterProQuantizedTokenizer
|
||||||
from .residue_tokenizer import ResidueAnnotationsTokenizer
|
from .residue_tokenizer import ResidueAnnotationsTokenizer
|
||||||
|
|||||||
@@ -10,24 +10,12 @@ from esm.models.function_decoder import FunctionTokenDecoder
|
|||||||
from esm.models.vqvae import StructureTokenDecoder
|
from esm.models.vqvae import StructureTokenDecoder
|
||||||
from esm.sdk.api import ESMProtein, ESMProteinTensor
|
from esm.sdk.api import ESMProtein, ESMProteinTensor
|
||||||
from esm.tokenization import TokenizerCollectionProtocol
|
from esm.tokenization import TokenizerCollectionProtocol
|
||||||
from esm.tokenization.function_tokenizer import (
|
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||||
InterProQuantizedTokenizer,
|
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||||
)
|
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||||
from esm.tokenization.residue_tokenizer import (
|
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||||
ResidueAnnotationsTokenizer,
|
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||||
)
|
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||||
from esm.tokenization.sasa_tokenizer import (
|
|
||||||
SASADiscretizingTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.sequence_tokenizer import (
|
|
||||||
EsmSequenceTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.ss_tokenizer import (
|
|
||||||
SecondaryStructureTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.structure_tokenizer import (
|
|
||||||
StructureTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
||||||
from esm.utils.constants import api as api_constants
|
from esm.utils.constants import api as api_constants
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
|
|||||||
@@ -7,26 +7,13 @@ from esm.models.vqvae import StructureTokenEncoder
|
|||||||
from esm.tokenization.function_tokenizer import (
|
from esm.tokenization.function_tokenizer import (
|
||||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||||
)
|
)
|
||||||
|
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||||
from esm.tokenization.residue_tokenizer import (
|
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||||
ResidueAnnotationsTokenizer,
|
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||||
)
|
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||||
from esm.tokenization.sasa_tokenizer import (
|
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||||
SASADiscretizingTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.sequence_tokenizer import (
|
|
||||||
EsmSequenceTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.ss_tokenizer import (
|
|
||||||
SecondaryStructureTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.structure_tokenizer import (
|
|
||||||
StructureTokenizer,
|
|
||||||
)
|
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.utils.function.encode_decode import (
|
from esm.utils.function.encode_decode import encode_function_annotations
|
||||||
encode_function_annotations,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.protein_chain import ProteinChain
|
from esm.utils.structure.protein_chain import ProteinChain
|
||||||
from esm.utils.types import FunctionAnnotation
|
from esm.utils.types import FunctionAnnotation
|
||||||
|
|
||||||
@@ -165,8 +152,6 @@ def tokenize_function_annotations(
|
|||||||
return function_tokens, residue_annotation_tokens
|
return function_tokens, residue_annotation_tokens
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Tokenized Defaults
|
# Tokenized Defaults
|
||||||
def get_default_sequence_tokens(
|
def get_default_sequence_tokens(
|
||||||
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
||||||
@@ -242,5 +227,3 @@ def get_default_residue_annotation_tokens(
|
|||||||
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
|
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
|
||||||
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
|
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
|
||||||
return residue_annotation_tokens
|
return residue_annotation_tokens
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,7 @@ from typing import Any, Callable, List
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from esm.sdk.api import ESMProteinError
|
from esm.sdk.api import ESMProteinError
|
||||||
from esm.sdk.retry import (
|
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
|
||||||
retry_if_specific_error,
|
|
||||||
skip_retries_var,
|
|
||||||
)
|
|
||||||
|
|
||||||
TQDM_BAR_FORMAT = (
|
TQDM_BAR_FORMAT = (
|
||||||
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
|
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
|
||||||
|
|||||||
@@ -3,16 +3,9 @@ from typing import Sequence
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from esm.models.function_decoder import (
|
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
|
||||||
FunctionTokenDecoder,
|
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||||
merge_annotations,
|
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||||
)
|
|
||||||
from esm.tokenization.function_tokenizer import (
|
|
||||||
InterProQuantizedTokenizer,
|
|
||||||
)
|
|
||||||
from esm.tokenization.residue_tokenizer import (
|
|
||||||
ResidueAnnotationsTokenizer,
|
|
||||||
)
|
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.utils.types import FunctionAnnotation
|
from esm.utils.types import FunctionAnnotation
|
||||||
|
|
||||||
|
|||||||
@@ -19,13 +19,8 @@ from esm.sdk.api import (
|
|||||||
SamplingConfig,
|
SamplingConfig,
|
||||||
SamplingTrackConfig,
|
SamplingTrackConfig,
|
||||||
)
|
)
|
||||||
from esm.tokenization import (
|
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
|
||||||
EsmTokenizerBase,
|
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||||
TokenizerCollectionProtocol,
|
|
||||||
)
|
|
||||||
from esm.tokenization.function_tokenizer import (
|
|
||||||
InterProQuantizedTokenizer,
|
|
||||||
)
|
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.utils.misc import stack_variable_length_tensors
|
from esm.utils.misc import stack_variable_length_tensors
|
||||||
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
from esm.utils.msa.msa import (
|
from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence
|
||||||
MSA,
|
|
||||||
FastMSA,
|
|
||||||
remove_insertions_from_sequence,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]
|
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]
|
||||||
|
|||||||
@@ -12,15 +12,8 @@ from Bio import SeqIO
|
|||||||
from scipy.spatial.distance import cdist
|
from scipy.spatial.distance import cdist
|
||||||
|
|
||||||
from esm.utils.misc import slice_any_object
|
from esm.utils.misc import slice_any_object
|
||||||
from esm.utils.msa.filter_sequences import (
|
from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter
|
||||||
greedy_select_indices,
|
from esm.utils.parsing import FastaEntry, read_sequences, write_sequences
|
||||||
hhfilter,
|
|
||||||
)
|
|
||||||
from esm.utils.parsing import (
|
|
||||||
FastaEntry,
|
|
||||||
read_sequences,
|
|
||||||
write_sequences,
|
|
||||||
)
|
|
||||||
from esm.utils.sequential_dataclass import SequentialDataclass
|
from esm.utils.sequential_dataclass import SequentialDataclass
|
||||||
from esm.utils.system import PathOrBuffer
|
from esm.utils.system import PathOrBuffer
|
||||||
|
|
||||||
|
|||||||
@@ -5,18 +5,9 @@ import attr
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from esm.sdk.api import (
|
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
|
||||||
ESMProteinTensor,
|
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
|
||||||
SamplingConfig,
|
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||||
SamplingTrackConfig,
|
|
||||||
)
|
|
||||||
from esm.tokenization import (
|
|
||||||
TokenizerCollectionProtocol,
|
|
||||||
get_invalid_tokenizer_ids,
|
|
||||||
)
|
|
||||||
from esm.tokenization.function_tokenizer import (
|
|
||||||
InterProQuantizedTokenizer,
|
|
||||||
)
|
|
||||||
from esm.utils.constants.esm3 import (
|
from esm.utils.constants.esm3 import (
|
||||||
MAX_RESIDUE_ANNOTATIONS,
|
MAX_RESIDUE_ANNOTATIONS,
|
||||||
SASA_DISCRETIZATION_BOUNDARIES,
|
SASA_DISCRETIZATION_BOUNDARIES,
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ from typing import Any, ClassVar, Protocol, TypeVar
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from esm.utils.structure.protein_structure import (
|
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
|
||||||
compute_affine_and_rmsd,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Alignable(Protocol):
|
class Alignable(Protocol):
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from esm.utils.structure.protein_structure import (
|
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||||
index_by_atom_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AtomIndexer:
|
class AtomIndexer:
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from typing import Any, Sequence
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Modification:
|
class Modification:
|
||||||
position: int # zero-indexed
|
position: int # zero-indexed
|
||||||
|
|||||||
@@ -16,14 +16,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from esm.utils import residue_constants
|
from esm.utils import residue_constants
|
||||||
from esm.utils.structure.metrics import (
|
from esm.utils.structure.metrics import compute_lddt, compute_rmsd
|
||||||
compute_lddt,
|
from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata
|
||||||
compute_rmsd,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.protein_complex import (
|
|
||||||
ProteinComplex,
|
|
||||||
ProteinComplexMetadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -25,21 +25,13 @@ from esm.utils.misc import slice_python_object_as_numpy
|
|||||||
from esm.utils.structure.affine3d import Affine3D
|
from esm.utils.structure.affine3d import Affine3D
|
||||||
from esm.utils.structure.aligner import Aligner
|
from esm.utils.structure.aligner import Aligner
|
||||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||||
from esm.utils.structure.metrics import (
|
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||||
compute_gdt_ts,
|
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
|
||||||
compute_lddt_ca,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.mmcif_parsing import (
|
|
||||||
MmcifWrapper,
|
|
||||||
Residue,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.normalize_coordinates import (
|
from esm.utils.structure.normalize_coordinates import (
|
||||||
apply_frame_to_coords,
|
apply_frame_to_coords,
|
||||||
get_protein_normalization_frame,
|
get_protein_normalization_frame,
|
||||||
)
|
)
|
||||||
from esm.utils.structure.protein_structure import (
|
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||||
index_by_atom_name,
|
|
||||||
)
|
|
||||||
from esm.utils.types import PathOrBuffer
|
from esm.utils.types import PathOrBuffer
|
||||||
|
|
||||||
msgpack_numpy.patch()
|
msgpack_numpy.patch()
|
||||||
@@ -401,7 +393,6 @@ class ProteinChain:
|
|||||||
bytes = input
|
bytes = input
|
||||||
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
|
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
|
||||||
|
|
||||||
|
|
||||||
def sasa(self, by_residue: bool = True):
|
def sasa(self, by_residue: bool = True):
|
||||||
arr = self.atom_array_no_insertions
|
arr = self.atom_array_no_insertions
|
||||||
sasa_per_atom = bs.sasa(arr) # type: ignore
|
sasa_per_atom = bs.sasa(arr) # type: ignore
|
||||||
@@ -707,7 +698,6 @@ class ProteinChain:
|
|||||||
)
|
)
|
||||||
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def chain_iterable_from_mmcif(
|
def chain_iterable_from_mmcif(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -32,14 +32,8 @@ from esm.utils.misc import slice_python_object_as_numpy
|
|||||||
from esm.utils.structure.affine3d import Affine3D
|
from esm.utils.structure.affine3d import Affine3D
|
||||||
from esm.utils.structure.aligner import Aligner
|
from esm.utils.structure.aligner import Aligner
|
||||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||||
from esm.utils.structure.metrics import (
|
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||||
compute_gdt_ts,
|
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
|
||||||
compute_lddt_ca,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.mmcif_parsing import (
|
|
||||||
MmcifWrapper,
|
|
||||||
NoProteinError,
|
|
||||||
)
|
|
||||||
from esm.utils.structure.protein_chain import (
|
from esm.utils.structure.protein_chain import (
|
||||||
ProteinChain,
|
ProteinChain,
|
||||||
chain_to_ndarray,
|
chain_to_ndarray,
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import pygtrie
|
|||||||
from ipywidgets import widgets
|
from ipywidgets import widgets
|
||||||
|
|
||||||
from esm.sdk.api import FunctionAnnotation
|
from esm.sdk.api import FunctionAnnotation
|
||||||
from esm.tokenization.function_tokenizer import (
|
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||||
InterProQuantizedTokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
TRIE: pygtrie.CharTrie | None = None
|
TRIE: pygtrie.CharTrie | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -7,15 +7,11 @@ import matplotlib.colors as mcolors
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from esm.sdk.api import ESMProtein
|
from esm.sdk.api import ESMProtein
|
||||||
from esm.widgets.utils.drawing.draw_category_array import (
|
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||||
draw_data_array,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.drawing.draw_function_annotations import (
|
from esm.widgets.utils.drawing.draw_function_annotations import (
|
||||||
draw_function_annotations,
|
draw_function_annotations,
|
||||||
)
|
)
|
||||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||||
draw_protein_structure,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.serialization import (
|
from esm.widgets.utils.serialization import (
|
||||||
create_download_button_from_buffer,
|
create_download_button_from_buffer,
|
||||||
protein_to_pdb_buffer,
|
protein_to_pdb_buffer,
|
||||||
|
|||||||
@@ -3,16 +3,9 @@ from typing import Any, Callable, Sequence
|
|||||||
import ipywidgets as widgets
|
import ipywidgets as widgets
|
||||||
|
|
||||||
from esm.utils.structure.protein_chain import ProteinChain
|
from esm.utils.structure.protein_chain import ProteinChain
|
||||||
from esm.widgets.utils.drawing.colors import (
|
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||||
hex_to_rgba_tuple,
|
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||||
rgba_tuple_to_hex,
|
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||||
)
|
|
||||||
from esm.widgets.utils.drawing.draw_category_array import (
|
|
||||||
draw_data_array,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.parsing import (
|
|
||||||
convert_range_string_to_list_of_ranges,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.prompting import PromptManager
|
from esm.widgets.utils.prompting import PromptManager
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,9 @@ import ipywidgets as widgets
|
|||||||
import pydssp
|
import pydssp
|
||||||
|
|
||||||
from esm.utils.structure.protein_chain import ProteinChain
|
from esm.utils.structure.protein_chain import ProteinChain
|
||||||
from esm.widgets.utils.drawing.colors import (
|
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||||
hex_to_rgba_tuple,
|
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||||
rgba_tuple_to_hex,
|
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||||
)
|
|
||||||
from esm.widgets.utils.drawing.draw_category_array import (
|
|
||||||
draw_data_array,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.parsing import (
|
|
||||||
convert_range_string_to_list_of_ranges,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.prompting import PromptManager
|
from esm.widgets.utils.prompting import PromptManager
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ from esm.widgets.utils.drawing.colors import (
|
|||||||
hex_to_rgba_tuple,
|
hex_to_rgba_tuple,
|
||||||
rgba_tuple_to_rgba_html_string,
|
rgba_tuple_to_rgba_html_string,
|
||||||
)
|
)
|
||||||
from esm.widgets.utils.parsing import (
|
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||||
convert_range_string_to_list_of_ranges,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.prompting import PromptManager
|
from esm.widgets.utils.prompting import PromptManager
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,8 @@ from matplotlib.patches import Rectangle
|
|||||||
|
|
||||||
from esm.utils.structure.protein_chain import ProteinChain
|
from esm.utils.structure.protein_chain import ProteinChain
|
||||||
from esm.widgets.utils import indexing
|
from esm.widgets.utils import indexing
|
||||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||||
draw_protein_structure,
|
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||||
)
|
|
||||||
from esm.widgets.utils.parsing import (
|
|
||||||
convert_range_string_to_list_of_ranges,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.printing import wrapped_print
|
from esm.widgets.utils.printing import wrapped_print
|
||||||
from esm.widgets.utils.prompting import PromptManager
|
from esm.widgets.utils.prompting import PromptManager
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,7 @@ from matplotlib import colormaps
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from esm.sdk.api import FunctionAnnotation
|
from esm.sdk.api import FunctionAnnotation
|
||||||
from esm.utils.function.interpro import (
|
from esm.utils.function.interpro import InterPro, InterProEntryType
|
||||||
InterPro,
|
|
||||||
InterProEntryType,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|||||||
@@ -9,9 +9,7 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
|
|||||||
from esm.utils import encoding
|
from esm.utils import encoding
|
||||||
from esm.widgets.utils import indexing
|
from esm.widgets.utils import indexing
|
||||||
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
|
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
|
||||||
from esm.widgets.utils.drawing.draw_category_array import (
|
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||||
draw_data_array,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.printing import wrapped_print
|
from esm.widgets.utils.printing import wrapped_print
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,13 +13,9 @@ from esm.sdk.api import (
|
|||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
)
|
)
|
||||||
from esm.utils.constants import models
|
from esm.utils.constants import models
|
||||||
from esm.widgets.components.results_visualizer import (
|
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||||
create_results_visualizer,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.printing import wrapped_print
|
from esm.widgets.utils.printing import wrapped_print
|
||||||
from esm.widgets.utils.serialization import (
|
from esm.widgets.utils.serialization import create_download_results_button
|
||||||
create_download_results_button,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_esm3_generation_launcher(
|
def create_esm3_generation_launcher(
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from ipywidgets import widgets
|
from ipywidgets import widgets
|
||||||
|
|
||||||
from esm.widgets.components.sasa_prompt_selector import (
|
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
|
||||||
create_sasa_prompt_selector,
|
|
||||||
)
|
|
||||||
from esm.widgets.components.secondary_structure_prompt_selector import (
|
from esm.widgets.components.secondary_structure_prompt_selector import (
|
||||||
create_secondary_structure_prompt_selector,
|
create_secondary_structure_prompt_selector,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,20 +4,12 @@ from ipywidgets import widgets
|
|||||||
|
|
||||||
from esm.sdk.api import ESM3InferenceClient, ESMProtein
|
from esm.sdk.api import ESM3InferenceClient, ESMProtein
|
||||||
from esm.utils.constants import esm3 as C
|
from esm.utils.constants import esm3 as C
|
||||||
from esm.widgets.components.function_annotator import (
|
from esm.widgets.components.function_annotator import create_function_annotator
|
||||||
create_function_annotator,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.prompting import PromptManagerCollection
|
from esm.widgets.utils.prompting import PromptManagerCollection
|
||||||
from esm.widgets.utils.protein_import import ProteinImporter
|
from esm.widgets.utils.protein_import import ProteinImporter
|
||||||
from esm.widgets.views.esm3_generation_launcher import (
|
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
|
||||||
create_esm3_generation_launcher,
|
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
|
||||||
)
|
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
|
||||||
from esm.widgets.views.esm3_prompt_preview import (
|
|
||||||
create_esm3_prompt_preview,
|
|
||||||
)
|
|
||||||
from esm.widgets.views.esm3_prompt_selector import (
|
|
||||||
create_esm3_prompt_selector,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_generation_ui(
|
def create_generation_ui(
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ from esm.sdk.api import (
|
|||||||
ESMProteinError,
|
ESMProteinError,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
)
|
)
|
||||||
from esm.widgets.components.results_visualizer import (
|
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||||
create_results_visualizer,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.printing import wrapped_print
|
from esm.widgets.utils.printing import wrapped_print
|
||||||
from esm.widgets.utils.protein_import import ProteinImporter
|
from esm.widgets.utils.protein_import import ProteinImporter
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,7 @@ from textwrap import dedent
|
|||||||
|
|
||||||
from ipywidgets import widgets
|
from ipywidgets import widgets
|
||||||
|
|
||||||
from esm.widgets.utils.clients import (
|
from esm.widgets.utils.clients import get_forge_client, get_local_client
|
||||||
get_forge_client,
|
|
||||||
get_local_client,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.types import ClientInitContainer
|
from esm.widgets.utils.types import ClientInitContainer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ from esm.sdk.api import (
|
|||||||
ESMProteinError,
|
ESMProteinError,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
)
|
)
|
||||||
from esm.widgets.components.results_visualizer import (
|
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||||
create_results_visualizer,
|
|
||||||
)
|
|
||||||
from esm.widgets.utils.printing import wrapped_print
|
from esm.widgets.utils.printing import wrapped_print
|
||||||
from esm.widgets.utils.protein_import import ProteinImporter
|
from esm.widgets.utils.protein_import import ProteinImporter
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import os
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from esm.sdk import client # pyright: ignore
|
from esm.sdk import client # pyright: ignore
|
||||||
from esm.sdk.api import ( # pyright: ignore
|
from esm.sdk.api import ( # pyright: ignore
|
||||||
ESMProtein,
|
ESMProtein,
|
||||||
|
|||||||
Reference in New Issue
Block a user