mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
sync 3.2.2.post1 (#270)
This commit is contained in:
@@ -38,7 +38,6 @@
|
||||
"\n",
|
||||
"!pip install py3Dmol\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.models.esm3 import ESM3\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -2,7 +2,6 @@ import random
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.pretrained import (
|
||||
ESM3_function_decoder_v0,
|
||||
ESM3_sm_open_v0,
|
||||
@@ -13,7 +12,9 @@ from esm.tokenization import get_esm3_model_tokenizers
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
|
||||
@@ -72,7 +72,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from biotite.database import rcsb\n",
|
||||
"\n",
|
||||
"from esm.sdk.api import ESMProtein\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain\n",
|
||||
"from esm.utils.types import FunctionAnnotation\n",
|
||||
@@ -497,9 +496,8 @@
|
||||
"# Functions for visualizing InterPro function annotations\n",
|
||||
"\n",
|
||||
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def visualize_function_annotations(\n",
|
||||
|
||||
@@ -49,18 +49,18 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
|
||||
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"token = getpass(\"Token from Forge console: \")"
|
||||
"token = getpass(\"Token from Forge: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -64,7 +64,6 @@
|
||||
"import matplotlib.pyplot as pl\n",
|
||||
"import py3Dmol\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
@@ -80,18 +79,18 @@
|
||||
"\n",
|
||||
"The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
|
||||
"\n",
|
||||
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
|
||||
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "zNrU9Q2SYonX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"token = getpass(\"Token from Forge console: \")"
|
||||
"token = getpass(\"Token from Forge: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -36,7 +36,6 @@
|
||||
"\n",
|
||||
"!pip install py3Dmol\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
@@ -53,7 +52,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
|
||||
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -64,7 +63,7 @@
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"token = getpass(\"Token from Forge console: \")\n",
|
||||
"token = getpass(\"Token from Forge: \")\n",
|
||||
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -49,7 +49,6 @@
|
||||
"source": [
|
||||
"import biotite.structure as bs\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
|
||||
]
|
||||
@@ -120,7 +119,7 @@
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"\n",
|
||||
"token = getpass(\"Token from Forge console: \")\n",
|
||||
"token = getpass(\"Token from Forge: \")\n",
|
||||
"model = client(\n",
|
||||
" model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
|
||||
")"
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
__version__ = "3.2.2"
|
||||
__version__ = "3.2.2.post1"
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
|
||||
from esm.layers.rotary import (
|
||||
RotaryEmbedding,
|
||||
TritonRotaryEmbedding,
|
||||
)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
|
||||
@@ -2,8 +2,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
|
||||
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
|
||||
from esm.layers.attention import (
|
||||
FlashMultiHeadAttention,
|
||||
MultiHeadAttention,
|
||||
)
|
||||
from esm.layers.geom_attention import (
|
||||
GeometricReasoningOriginalImpl,
|
||||
)
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from esm.utils.constants.physics import BB_COORDINATES
|
||||
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
RotationMatrix,
|
||||
)
|
||||
|
||||
|
||||
class Dim6RotStructureHead(nn.Module):
|
||||
|
||||
@@ -13,7 +13,10 @@ from attr import dataclass
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||
from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
@@ -29,7 +32,10 @@ from esm.sdk.api import (
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
from esm.utils.decoding import decode_protein_tensor
|
||||
from esm.utils.generation import (
|
||||
_batch_forward,
|
||||
@@ -44,7 +50,9 @@ from esm.utils.sampling import (
|
||||
get_default_sampling_config,
|
||||
validate_sampling_config,
|
||||
)
|
||||
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
|
||||
from esm.utils.structure.affine3d import (
|
||||
build_affine3d_from_coordinates,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,9 @@ from cloudpathlib import AnyPath
|
||||
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import merge_annotations, merge_ranges
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -7,7 +7,10 @@ from esm.layers.structure_proj import Dim6RotStructureHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import knn_graph
|
||||
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
build_affine3d_from_coordinates,
|
||||
)
|
||||
from esm.utils.structure.predicted_aligned_error import (
|
||||
compute_predicted_aligned_error,
|
||||
compute_tm,
|
||||
|
||||
@@ -6,8 +6,14 @@ import torch.nn as nn
|
||||
from esm.models.esm3 import ESM3
|
||||
from esm.models.esmc import ESMC
|
||||
from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
|
||||
from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
get_esm3_model_tokenizers,
|
||||
get_esmc_model_tokenizers,
|
||||
)
|
||||
from esm.utils.constants.esm3 import data_root
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_FUNCTION_DECODER_V0,
|
||||
|
||||
@@ -2,19 +2,27 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
from typing import List, Sequence
|
||||
|
||||
import attr
|
||||
import torch
|
||||
from attr import asdict, define
|
||||
|
||||
import esm.utils.constants.api as C
|
||||
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_esm3_model_tokenizers,
|
||||
)
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
|
||||
from esm.utils.misc import (
|
||||
get_chainbreak_boundaries_from_sequence,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
|
||||
from esm.utils.structure.protein_complex import (
|
||||
SINGLE_LETTER_CHAIN_IDS,
|
||||
ProteinComplex,
|
||||
)
|
||||
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
||||
|
||||
|
||||
@@ -35,6 +43,7 @@ class ESMProtein(ProteinType):
|
||||
plddt: torch.Tensor | None = None
|
||||
ptm: torch.Tensor | None = None
|
||||
|
||||
|
||||
# When calling EvolutionaryScale API, use this flag to disclose any
|
||||
# sequences that may potentially have concerns.
|
||||
# Such sequences may not go through standard safety filter for approved users.
|
||||
@@ -148,12 +157,35 @@ class ESMProtein(ProteinType):
|
||||
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
|
||||
else:
|
||||
gt_chains = None
|
||||
|
||||
# Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
|
||||
# This handles the case where the server doesn't include chain breaks in pLDDT
|
||||
# We should fix this in the server side.
|
||||
if self.plddt is not None and len(self.plddt) != len(self.sequence):
|
||||
# Only expand if there's a mismatch (likely due to chain breaks)
|
||||
if "|" in self.sequence:
|
||||
# Create expanded pLDDT with NaN at chain break positions
|
||||
expanded_plddt = torch.full((len(self.sequence),), float("nan"))
|
||||
plddt_idx = 0
|
||||
for i, aa in enumerate(self.sequence):
|
||||
if aa != "|":
|
||||
if plddt_idx < len(self.plddt):
|
||||
expanded_plddt[i] = self.plddt[plddt_idx]
|
||||
plddt_idx += 1
|
||||
plddt = expanded_plddt
|
||||
else:
|
||||
# Mismatch but no chain breaks - shouldn't happen but preserve original
|
||||
plddt = self.plddt
|
||||
else:
|
||||
plddt = self.plddt
|
||||
|
||||
pred_chains = []
|
||||
for i, (start, end) in enumerate(chain_boundaries):
|
||||
if i >= len(SINGLE_LETTER_CHAIN_IDS):
|
||||
raise ValueError(
|
||||
f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
|
||||
)
|
||||
|
||||
pred_chain = ProteinChain.from_atom37(
|
||||
atom37_positions=coords[start:end],
|
||||
sequence=self.sequence[start:end],
|
||||
@@ -161,7 +193,7 @@ class ESMProtein(ProteinType):
|
||||
if gt_chains is not None
|
||||
else SINGLE_LETTER_CHAIN_IDS[i],
|
||||
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
|
||||
confidence=self.plddt[start:end] if self.plddt is not None else None,
|
||||
confidence=plddt[start:end] if plddt is not None else None,
|
||||
)
|
||||
pred_chains.append(pred_chain)
|
||||
return ProteinComplex.from_chains(pred_chains)
|
||||
@@ -298,19 +330,14 @@ class GenerationConfig:
|
||||
self.temperature_annealing = True
|
||||
|
||||
|
||||
@define
|
||||
class MSA:
|
||||
# Paired MSA sequences.
|
||||
# One would typically compute these using, for example, ColabFold.
|
||||
sequences: list[str]
|
||||
|
||||
|
||||
@define
|
||||
class InverseFoldingConfig:
|
||||
invalid_ids: Sequence[int] = []
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
|
||||
|
||||
## Low Level Endpoint Types
|
||||
@define
|
||||
class SamplingTrackConfig:
|
||||
@@ -375,6 +402,9 @@ class LogitsConfig:
|
||||
ith_hidden_layer: int = -1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@define
|
||||
class LogitsOutput:
|
||||
logits: ForwardTrackData | None = None
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.decoding import assemble_message
|
||||
|
||||
|
||||
@@ -80,6 +84,10 @@ class _BaseForgeInferenceClient:
|
||||
headers = {**self.headers, **headers}
|
||||
if return_bytes:
|
||||
headers["return-bytes"] = "true"
|
||||
# __INTERNAL_BEGIN___
|
||||
if disable_cache:
|
||||
headers["X-Disable-Cache"] = "true"
|
||||
# __INTERNAL_END___
|
||||
return request, headers
|
||||
|
||||
def prepare_data(self, response, endpoint: str) -> dict[str, Any]:
|
||||
@@ -112,7 +120,11 @@ class _BaseForgeInferenceClient:
|
||||
):
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
return_bytes,
|
||||
disable_cache,
|
||||
headers,
|
||||
)
|
||||
response = await self.async_client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
@@ -142,7 +154,10 @@ class _BaseForgeInferenceClient:
|
||||
):
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
return_bytes,
|
||||
headers,
|
||||
)
|
||||
response = self.client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
@@ -160,3 +175,5 @@ class _BaseForgeInferenceClient:
|
||||
error_code=500,
|
||||
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
|
||||
)
|
||||
|
||||
|
||||
|
||||
161
esm/sdk/forge.py
161
esm/sdk/forge.py
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import pickle
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Sequence
|
||||
from typing import Any, Literal, Sequence, cast
|
||||
|
||||
import torch
|
||||
|
||||
from esm.sdk.api import (
|
||||
MSA,
|
||||
ESM3InferenceClient,
|
||||
ESMCInferenceClient,
|
||||
ESMProtein,
|
||||
@@ -19,14 +20,30 @@ from esm.sdk.api import (
|
||||
InverseFoldingConfig,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
ProteinChain,
|
||||
ProteinType,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
|
||||
from esm.sdk.base_forge_client import (
|
||||
_BaseForgeInferenceClient,
|
||||
)
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
||||
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
|
||||
from esm.utils.misc import (
|
||||
deserialize_tensors,
|
||||
maybe_list,
|
||||
maybe_tensor,
|
||||
)
|
||||
from esm.utils.msa import MSA
|
||||
from esm.utils.structure.input_builder import (
|
||||
StructurePredictionInput,
|
||||
serialize_structure_prediction_input,
|
||||
)
|
||||
from esm.utils.structure.molecular_complex import (
|
||||
MolecularComplex,
|
||||
MolecularComplexResult,
|
||||
)
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -36,10 +53,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
|
||||
return [FunctionAnnotation(*t) for t in l]
|
||||
|
||||
|
||||
def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
|
||||
ret = data.get("logits", {}).get(track, None)
|
||||
# TODO(s22chan): just return this when removing return_bytes
|
||||
return ret if ret is None or not return_bytes else maybe_tensor(ret)
|
||||
def _maybe_logits(data: dict[str, Any], track: str):
|
||||
return maybe_tensor(data.get("logits", {}).get(track, None))
|
||||
|
||||
|
||||
def _maybe_b64_decode(obj, return_bytes: bool):
|
||||
@@ -93,9 +108,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_request(sequence: str, model_name: str | None):
|
||||
def _process_fold_request(
|
||||
sequence: str,
|
||||
model_name: str | None,
|
||||
):
|
||||
request: dict[str, Any] = {"sequence": sequence}
|
||||
|
||||
|
||||
request["model"] = model_name
|
||||
|
||||
return request
|
||||
@@ -130,6 +149,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
|
||||
return request
|
||||
|
||||
|
||||
async def _async_fetch_msa(self, sequence: str) -> MSA:
|
||||
print("Fetching MSA ... this may take a few minutes")
|
||||
# Accept both "|" and ":" as the chainbreak token.
|
||||
@@ -137,7 +157,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
data = await self._async_post(
|
||||
"msa", request={}, params={"sequence": sequence, "use_env": False}
|
||||
)
|
||||
return MSA(sequences=data["msa"])
|
||||
return MSA.from_sequences(sequences=data["msa"])
|
||||
|
||||
def _fetch_msa(self, sequence: str) -> MSA:
|
||||
print("Fetching MSA ... this may take a few minutes")
|
||||
@@ -146,7 +166,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
data = self._post(
|
||||
"msa", request={}, params={"sequence": sequence, "use_env": False}
|
||||
)
|
||||
return MSA(sequences=data["msa"])
|
||||
return MSA.from_sequences(sequences=data["msa"])
|
||||
|
||||
@retry_decorator
|
||||
async def async_fold(
|
||||
@@ -168,11 +188,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
del potential_sequence_of_concern
|
||||
|
||||
request = self._process_fold_request(
|
||||
sequence, model_name if model_name is not None else self.model
|
||||
sequence,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = await self._async_post("fold", request)
|
||||
data = await self._async_post(
|
||||
"fold",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
@@ -199,16 +223,98 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
del potential_sequence_of_concern
|
||||
|
||||
request = self._process_fold_request(
|
||||
sequence, model_name if model_name is not None else self.model
|
||||
sequence,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = self._post("fold", request)
|
||||
data = self._post(
|
||||
"fold",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_response(data, sequence)
|
||||
|
||||
@retry_decorator
|
||||
async def async_fold_all_atom(
|
||||
self,
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||
"""Fold a molecular complex containing proteins, nucleic acids, and/or ligands.
|
||||
|
||||
Args:
|
||||
all_atom_input: StructurePredictionInput containing sequences for different molecule types
|
||||
model_name: Override the client level model name if needed
|
||||
"""
|
||||
request = self._process_fold_all_atom_request(
|
||||
all_atom_input,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = await self._async_post(
|
||||
"fold_all_atom",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_all_atom_response(data)
|
||||
|
||||
@retry_decorator
|
||||
def fold_all_atom(
|
||||
self,
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
|
||||
"""Predict coordinates for a molecular complex containing proteins, dna, rna, and/or ligands.
|
||||
|
||||
Args:
|
||||
all_atom_input: StructurePredictionInput containing sequences for different molecule types
|
||||
model_name: Override the client level model name if needed
|
||||
"""
|
||||
request = self._process_fold_all_atom_request(
|
||||
all_atom_input,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
try:
|
||||
data = self._post(
|
||||
"fold_all_atom",
|
||||
request,
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_all_atom_response(data)
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_all_atom_request(
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
request: dict[str, Any] = {
|
||||
"all_atom_input": serialize_structure_prediction_input(all_atom_input),
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
|
||||
complex_data = data.get("complex")
|
||||
molecular_complex = MolecularComplex.from_state_dict(complex_data)
|
||||
return MolecularComplexResult(
|
||||
complex=molecular_complex,
|
||||
plddt=maybe_tensor(data.get("plddt"), convert_none_to_nan=True),
|
||||
ptm=data.get("ptm", None),
|
||||
distogram=maybe_tensor(data.get("distogram"), convert_none_to_nan=True),
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
async def async_inverse_fold(
|
||||
self,
|
||||
@@ -280,6 +386,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
return ESMProtein(sequence=data["sequence"])
|
||||
|
||||
|
||||
|
||||
class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -602,19 +709,15 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
|
||||
return LogitsOutput(
|
||||
logits=ForwardTrackData(
|
||||
sequence=_maybe_logits(data, "sequence", return_bytes),
|
||||
structure=_maybe_logits(data, "structure", return_bytes),
|
||||
secondary_structure=_maybe_logits(
|
||||
data, "secondary_structure", return_bytes
|
||||
),
|
||||
sasa=_maybe_logits(data, "sasa", return_bytes),
|
||||
function=_maybe_logits(data, "function", return_bytes),
|
||||
sequence=_maybe_logits(data, "sequence"),
|
||||
structure=_maybe_logits(data, "structure"),
|
||||
secondary_structure=_maybe_logits(data, "secondary_structure"),
|
||||
sasa=_maybe_logits(data, "sasa"),
|
||||
function=_maybe_logits(data, "function"),
|
||||
),
|
||||
embeddings=maybe_tensor(data["embeddings"]),
|
||||
mean_embedding=data["mean_embedding"],
|
||||
residue_annotation_logits=_maybe_logits(
|
||||
data, "residue_annotation", return_bytes
|
||||
),
|
||||
residue_annotation_logits=_maybe_logits(data, "residue_annotation"),
|
||||
hidden_states=maybe_tensor(data["hidden_states"]),
|
||||
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
|
||||
)
|
||||
@@ -965,6 +1068,7 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
"sequence": config.sequence,
|
||||
"return_embeddings": config.return_embeddings,
|
||||
"return_mean_embedding": config.return_mean_embedding,
|
||||
"return_mean_hidden_states": config.return_mean_hidden_states,
|
||||
"return_hidden_states": config.return_hidden_states,
|
||||
"ith_hidden_layer": config.ith_hidden_layer,
|
||||
}
|
||||
@@ -981,12 +1085,11 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
|
||||
|
||||
output = LogitsOutput(
|
||||
logits=ForwardTrackData(
|
||||
sequence=_maybe_logits(data, "sequence", return_bytes)
|
||||
),
|
||||
logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
|
||||
embeddings=maybe_tensor(data["embeddings"]),
|
||||
mean_embedding=data["mean_embedding"],
|
||||
hidden_states=maybe_tensor(data["hidden_states"]),
|
||||
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -1109,3 +1212,5 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
raise NotImplementedError(
|
||||
f"Can not get underlying remote model {self.model} from a Forge client."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ import inspect
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
retry_if_exception,
|
||||
retry_if_result,
|
||||
stop_after_attempt,
|
||||
wait_incrementing,
|
||||
@@ -30,8 +29,12 @@ def retry_if_specific_error(exception):
|
||||
|
||||
|
||||
def log_retry_attempt(retry_state):
|
||||
try:
|
||||
outcome = retry_state.outcome.result()
|
||||
except Exception:
|
||||
outcome = retry_state.outcome.exception()
|
||||
print(
|
||||
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
|
||||
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}"
|
||||
)
|
||||
|
||||
|
||||
@@ -41,13 +44,18 @@ def retry_decorator(func):
|
||||
instance's retry settings.
|
||||
"""
|
||||
|
||||
def return_last_value(retry_state):
|
||||
"""Return the result of the last call attempt."""
|
||||
return retry_state.outcome.result()
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(instance, *args, **kwargs):
|
||||
if skip_retries_var.get():
|
||||
return await func(instance, *args, **kwargs)
|
||||
retry_decorator = retry(
|
||||
retry_error_callback=return_last_value,
|
||||
retry=retry_if_result(retry_if_specific_error)
|
||||
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
|
||||
| retry_if_exception(retry_if_specific_error),
|
||||
wait=wait_incrementing(
|
||||
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
|
||||
),
|
||||
@@ -62,8 +70,9 @@ def retry_decorator(func):
|
||||
if skip_retries_var.get():
|
||||
return func(instance, *args, **kwargs)
|
||||
retry_decorator = retry(
|
||||
retry_error_callback=return_last_value,
|
||||
retry=retry_if_result(retry_if_specific_error)
|
||||
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
|
||||
| retry_if_exception(retry_if_specific_error),
|
||||
wait=wait_incrementing(
|
||||
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
|
||||
),
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
|
||||
from .function_tokenizer import InterProQuantizedTokenizer
|
||||
from .residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
|
||||
@@ -10,12 +10,24 @@ from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder
|
||||
from esm.sdk.api import ESMProtein, ESMProteinTensor
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.tokenization.sasa_tokenizer import (
|
||||
SASADiscretizingTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.ss_tokenizer import (
|
||||
SecondaryStructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.structure_tokenizer import (
|
||||
StructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
||||
from esm.utils.constants import api as api_constants
|
||||
from esm.utils.constants import esm3 as C
|
||||
|
||||
@@ -7,13 +7,26 @@ from esm.models.vqvae import StructureTokenEncoder
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.tokenization.sasa_tokenizer import (
|
||||
SASADiscretizingTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.ss_tokenizer import (
|
||||
SecondaryStructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.structure_tokenizer import (
|
||||
StructureTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.function.encode_decode import encode_function_annotations
|
||||
from esm.utils.function.encode_decode import (
|
||||
encode_function_annotations,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -152,6 +165,8 @@ def tokenize_function_annotations(
|
||||
return function_tokens, residue_annotation_tokens
|
||||
|
||||
|
||||
|
||||
|
||||
# Tokenized Defaults
|
||||
def get_default_sequence_tokens(
|
||||
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
||||
@@ -227,3 +242,5 @@ def get_default_residue_annotation_tokens(
|
||||
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
|
||||
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
|
||||
return residue_annotation_tokens
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,10 @@ from typing import Any, Callable, List
|
||||
from tqdm import tqdm
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
|
||||
from esm.sdk.retry import (
|
||||
retry_if_specific_error,
|
||||
skip_retries_var,
|
||||
)
|
||||
|
||||
TQDM_BAR_FORMAT = (
|
||||
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
|
||||
|
||||
@@ -3,9 +3,16 @@ from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.models.function_decoder import (
|
||||
FunctionTokenDecoder,
|
||||
merge_annotations,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -19,8 +19,13 @@ from esm.sdk.api import (
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization import (
|
||||
EsmTokenizerBase,
|
||||
TokenizerCollectionProtocol,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import stack_variable_length_tensors
|
||||
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
||||
@@ -43,7 +48,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
|
||||
|
||||
sliced = {}
|
||||
for k, v in attr.asdict(o, recurse=False).items():
|
||||
if v is None:
|
||||
if k in ["mean_hidden_state", "mean_embedding"]:
|
||||
sliced[k] = v
|
||||
elif v is None:
|
||||
sliced[k] = None
|
||||
elif isinstance(v, torch.Tensor):
|
||||
# Trim padding.
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import is_dataclass
|
||||
from io import BytesIO
|
||||
from typing import Any, ContextManager, Sequence, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
ContextManager,
|
||||
Generator,
|
||||
Iterable,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
runtime_checkable,
|
||||
)
|
||||
from warnings import warn
|
||||
|
||||
import huggingface_hub
|
||||
@@ -18,6 +30,12 @@ MAX_SUPPORTED_DISTANCE = 1e6
|
||||
TSequence = TypeVar("TSequence", bound=Sequence)
|
||||
|
||||
|
||||
@runtime_checkable
class Concatable(Protocol):
    """Structural type for objects that can concatenate instances of themselves.

    Any class exposing a ``concat`` attribute satisfies this protocol. Because
    it is ``runtime_checkable`` it can be used with ``isinstance`` and as a
    class pattern in ``match`` statements; such runtime checks only verify that
    ``concat`` exists, not its signature.
    """

    @classmethod
    def concat(cls, objs: list[Concatable]) -> Concatable:
        """Combine ``objs`` into a single object."""
        ...
|
||||
|
||||
|
||||
def slice_python_object_as_numpy(
|
||||
obj: TSequence, idx: int | list[int] | slice | np.ndarray
|
||||
) -> TSequence:
|
||||
@@ -52,6 +70,37 @@ def slice_python_object_as_numpy(
|
||||
return sliced_obj # type: ignore
|
||||
|
||||
|
||||
def slice_any_object(
    obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
    """
    Index an arbitrary object (list, string, tuple, array, ...) numpy-style.

    Works like `slice_python_object_as_numpy`, except that numpy arrays and
    torch Tensors are indexed directly with their own ``__getitem__``.

    Dataclass objects are also indexed directly, on the assumption that the
    dataclass correctly implements numpy-style indexing itself.

    Example:
    >>> obj = "ABCDE"
    >>> slice_any_object(obj, [1, 3, 4])
    "BDE"

    >>> obj = np.array([1, 2, 3, 4, 5])
    >>> slice_any_object(obj, np.arange(5) < 3)
    np.array([1, 2, 3])

    >>> obj = ProteinChain.from_rcsb("1a3a", "A")
    >>> slice_any_object(obj, np.arange(len(obj)) < 10)
    # ProteinChain w/ length 10

    """
    # Arrays, tensors and dataclasses are all trusted to handle `idx` natively.
    indexes_natively = isinstance(obj, (np.ndarray, torch.Tensor)) or is_dataclass(
        obj
    )
    if not indexes_natively:
        return slice_python_object_as_numpy(obj, idx)
    return obj[idx]  # type: ignore
|
||||
|
||||
|
||||
def rbf(values, v_min, v_max, n_bins=16):
|
||||
"""
|
||||
Returns RBF encodings in a new dimension at the end.
|
||||
@@ -298,6 +347,8 @@ def replace_inf(data):
|
||||
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x):
|
||||
return torch.stack(x)
|
||||
if convert_none_to_nan:
|
||||
@@ -357,3 +408,90 @@ def deserialize_tensors(b: bytes) -> Any:
|
||||
buf = BytesIO(zstd.ZSTD_uncompress(b))
|
||||
d = torch.load(buf, map_location="cpu", weights_only=False)
|
||||
return d
|
||||
|
||||
|
||||
def join_lists(
    lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None
) -> list[Any]:
    """Chain several lists together, optionally inserting a separator between them.

    This is the list analogue of ``str.join``.

    Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4]

    Args:
        lists: Lists of elements to chain.
        separator: Elements to insert between consecutive lists. A ``None`` or
            empty separator inserts nothing.
    Returns:
        A new list holding the joined elements.
    """
    if not lists:
        return []

    head, tail = lists[0], lists[1:]
    result: list[Any] = []
    result.extend(head)
    for chunk in tail:
        # Truthiness check: both None and an empty separator are skipped.
        if separator:
            result.extend(separator)
        result.extend(chunk)
    return result
|
||||
|
||||
|
||||
def iterate_with_intermediate(
    lists: Iterable, intermediate
) -> Generator[Any, None, None]:
    """
    Iterate over the iterable, yielding the intermediate value between
    every pair of consecutive elements. Useful for joining objects with
    separator tokens.

    Args:
        lists: Any iterable of elements.
        intermediate: Value yielded between consecutive elements. The same
            object is yielded each time (it is not copied).

    Yields:
        ``e0, intermediate, e1, intermediate, ..., eN``. Nothing is yielded
        for an empty iterable.
    """
    it = iter(lists)
    # A bare `next(it)` would leak StopIteration out of the generator body,
    # which PEP 479 converts into a RuntimeError. Treat empty input as an
    # empty result instead.
    try:
        first = next(it)
    except StopIteration:
        return
    yield first
    for item in it:
        yield intermediate
        yield item
|
||||
|
||||
|
||||
def concat_objects(objs: Sequence[Any], separator: Any | None = None):
|
||||
"""
|
||||
Concat objects with each other using a separator token.
|
||||
|
||||
Supports:
|
||||
- Concatable (objects that implement `concat` classmethod)
|
||||
- strings
|
||||
- lists
|
||||
- numpy arrays
|
||||
- torch Tensors
|
||||
|
||||
Example:
|
||||
>>> foo = "abc"
|
||||
>>> bar = "def"
|
||||
>>> concat_objects([foo, bar], "|")
|
||||
"abc|def"
|
||||
"""
|
||||
match objs[0]:
|
||||
case Concatable():
|
||||
return objs[0].__class__.concat(objs) # type: ignore
|
||||
case str():
|
||||
assert isinstance(
|
||||
separator, str
|
||||
), "Trying to join strings but separator is not a string"
|
||||
return separator.join(objs)
|
||||
case list():
|
||||
if separator is not None:
|
||||
return join_lists(objs, [separator])
|
||||
else:
|
||||
return join_lists(objs)
|
||||
case np.ndarray():
|
||||
if separator is not None:
|
||||
return np.concatenate(
|
||||
list(iterate_with_intermediate(objs, np.array([separator])))
|
||||
)
|
||||
else:
|
||||
return np.concatenate(objs)
|
||||
case torch.Tensor():
|
||||
if separator is not None:
|
||||
return torch.cat(
|
||||
list(iterate_with_intermediate(objs, torch.tensor([separator])))
|
||||
)
|
||||
else:
|
||||
return torch.cat(objs) # type: ignore
|
||||
case _:
|
||||
raise TypeError(type(objs[0]))
|
||||
|
||||
7
esm/utils/msa/__init__.py
Normal file
7
esm/utils/msa/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from esm.utils.msa.msa import (
|
||||
MSA,
|
||||
FastMSA,
|
||||
remove_insertions_from_sequence,
|
||||
)
|
||||
|
||||
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]
|
||||
79
esm/utils/msa/filter_sequences.py
Normal file
79
esm/utils/msa/filter_sequences.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from esm.utils.system import run_subprocess_with_errorcheck
|
||||
|
||||
|
||||
def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
    """Greedily pick rows that maximize (or minimize) average hamming distance.

    This is the selection algorithm proposed in the MSA Transformer paper.
    Row 0 (the query sequence) is always selected first; each following step
    adds the not-yet-selected row whose mean Hamming distance to the rows
    selected so far is largest (``mode="max"``) or smallest (``mode="min"``).

    Args:
        array (np.ndarray): Character array representing the sequences in the MSA
        num_seqs (int): Number of sequences to select.
        mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
            you're doing it to prove a point for a paper.

    Returns:
        list[int]: Sorted list of row indices to select from the array. If the
            array has at most ``num_seqs`` rows, every row index is returned.
    """
    assert mode in ("max", "min")
    n_rows = array.shape[0]
    if n_rows <= num_seqs:
        return list(range(n_rows))

    # Reinterpret the characters as raw bytes so cdist can compare them.
    encoded = array.view(np.uint8)

    pick = np.argmin if mode == "min" else np.argmax
    row_ids = np.arange(n_rows)
    chosen = [0]
    # One row of distances-to-all is appended per selected sequence.
    distances = np.zeros((0, n_rows))
    for _ in range(num_seqs - 1):
        latest = cdist(encoded[chosen[-1:]], encoded, "hamming")
        distances = np.concatenate([distances, latest])
        # Score only rows that have not been chosen yet; positions in the
        # reduced score vector line up with `candidates`.
        candidate_scores = np.delete(distances, chosen, axis=1).mean(0)
        candidates = np.delete(row_ids, chosen)
        chosen.append(candidates[pick(candidate_scores)])
    return sorted(chosen)
|
||||
|
||||
|
||||
def hhfilter(
    sequences: list[str],
    seqid: int = 90,
    diff: int = 0,
    cov: int = 0,
    qid: int = 0,
    qsc: float = -20.0,
    binary: str = "hhfilter",
) -> list[int]:
    """Run the external ``hhfilter`` program and return the indices it keeps.

    The sequences are written to a temporary FASTA file whose headers are the
    sequence positions (``>0``, ``>1``, ...). After the program has run, the
    headers remaining in its output file are parsed back into integers.

    Args:
        sequences: Aligned sequences to filter.
        seqid: Value passed to hhfilter's ``-id`` option.
        diff: Value passed to hhfilter's ``-diff`` option.
        cov: Value passed to hhfilter's ``-cov`` option.
        qid: Value passed to hhfilter's ``-qid`` option.
        qsc: Value passed to hhfilter's ``-qsc`` option.
        binary: Name or path of the hhfilter executable.

    Returns:
        Indices into ``sequences`` of the entries present in hhfilter's output,
        in the order they appear there.
    """
    # Use /dev/shm when it exists (as the original code always did); fall back
    # to the platform default temp dir elsewhere instead of failing outright.
    shm = Path("/dev/shm")
    tmp_root = str(shm) if shm.is_dir() else None
    with tempfile.TemporaryDirectory(dir=tmp_root) as tempdirname:
        tempdir = Path(tempdirname)
        fasta_file = tempdir / "input.fasta"
        fasta_file.write_text(
            "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
        )
        output_file = tempdir / "output.fasta"
        # Build argv directly rather than joining and re-splitting on spaces,
        # so that paths containing spaces remain single arguments.
        command = [
            str(binary),
            "-i",
            str(fasta_file),
            "-M",
            "a3m",
            "-o",
            str(output_file),
            "-id",
            str(seqid),
            "-diff",
            str(diff),
            "-cov",
            str(cov),
            "-qid",
            str(qid),
            "-qsc",
            str(qsc),
        ]
        run_subprocess_with_errorcheck(command, capture_output=True)
        with output_file.open() as f:
            indices = [int(line[1:].strip()) for line in f if line.startswith(">")]
    return indices
|
||||
507
esm/utils/msa/msa.py
Normal file
507
esm/utils/msa/msa.py
Normal file
@@ -0,0 +1,507 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from itertools import islice
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
from Bio import SeqIO
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from esm.utils.misc import slice_any_object
|
||||
from esm.utils.msa.filter_sequences import (
|
||||
greedy_select_indices,
|
||||
hhfilter,
|
||||
)
|
||||
from esm.utils.parsing import (
|
||||
FastaEntry,
|
||||
read_sequences,
|
||||
write_sequences,
|
||||
)
|
||||
from esm.utils.sequential_dataclass import SequentialDataclass
|
||||
from esm.utils.system import PathOrBuffer
|
||||
|
||||
# Translation table mapping every ASCII lowercase letter to None (i.e. delete it).
REMOVE_LOWERCASE_TRANSLATION = str.maketrans("", "", string.ascii_lowercase)


def remove_insertions_from_sequence(seq: str) -> str:
    """Return ``seq`` with every ASCII lowercase letter removed.

    Lowercase letters are the characters this module treats as insertions;
    all other characters (uppercase letters, gaps, ...) are kept unchanged.
    """
    without_lowercase = seq.translate(REMOVE_LOWERCASE_TRANSLATION)
    return without_lowercase
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MSA(SequentialDataclass):
    """Object-oriented interface to an MSA.

    The first entry is treated as the query sequence, and indexing / ``len``
    operate along the sequence (column) dimension.

    Args:
        entries (list[FastaEntry]): ``(header, sequence)`` records, one per
            row of the alignment. All sequences are expected to have the
            same length.

    """

    entries: list[FastaEntry]

    @cached_property
    def sequences(self) -> list[str]:
        """Sequence strings of all entries, in row order."""
        return [entry.sequence for entry in self.entries]

    @cached_property
    def headers(self) -> list[str]:
        """Header strings of all entries, in row order."""
        return [entry.header for entry in self.entries]

    def __repr__(self):
        # Guard the empty case: both the first header and `seqlen` index entries[0].
        if not self.entries:
            return "MSA(<empty>)"
        return (
            f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})"
        )

    def to_fast_msa(self) -> FastMSA:
        """Convert to the array-backed `FastMSA` representation."""
        return FastMSA(self.array, self.headers)

    @classmethod
    def from_a3m(
        cls,
        path: PathOrBuffer,
        remove_insertions: bool = True,
        max_sequences: int | None = None,
    ) -> MSA:
        """Read an MSA from an A3M/FASTA source.

        Args:
            path: Path or buffer understood by `read_sequences`.
            remove_insertions: If True, strip lowercase characters from every
                sequence before storing it.
            max_sequences: Read at most this many sequences (None reads all).

        Raises:
            AssertionError: If a sequence length differs from the first one.
        """
        entries = []
        for header, seq in islice(read_sequences(path), max_sequences):
            if remove_insertions:
                seq = remove_insertions_from_sequence(seq)
            if entries:
                assert (
                    len(seq) == len(entries[0].sequence)
                ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
            entries.append(FastaEntry(header, seq))
        return cls(entries)

    def to_a3m(self, path: PathOrBuffer) -> None:
        """Write all entries to ``path`` via `write_sequences`."""
        write_sequences(self.entries, path)

    @classmethod
    def from_stockholm(
        cls,
        path: PathOrBuffer,
        remove_insertions: bool = True,
        max_sequences: int | None = None,
    ) -> MSA:
        """Read an MSA from a Stockholm file using Biopython.

        Args:
            path: Source handed to ``SeqIO.parse(path, "stockholm")``.
            remove_insertions: If True, drop every column where the query
                (first sequence) has a gap character ``-``.
            max_sequences: Read at most this many records (None reads all).

        Raises:
            AssertionError: If a sequence length differs from the first one.
        """
        entries = []
        for record in islice(SeqIO.parse(path, "stockholm"), max_sequences):
            header = f"{record.id} {record.description}"
            seq = str(record.seq)
            if entries:
                assert (
                    len(seq) == len(entries[0].sequence)
                ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
            entries.append(FastaEntry(header, seq))
        msa = cls(entries)
        if remove_insertions:
            keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"]
            msa = msa.select_positions(keep_inds)
        return msa

    def to_bytes(self) -> bytes:
        """Serialize sequences and headers.

        Layout: 1 byte version, 4 bytes seqlen (little endian), 4 bytes depth
        (little endian), ``depth * seqlen`` bytes of characters, then the
        headers joined with newlines.
        """
        version = 1
        version_bytes = version.to_bytes(1, "little")
        seqlen_bytes = self.seqlen.to_bytes(4, "little")
        depth_bytes = self.depth.to_bytes(4, "little")
        array_bytes = self.array.tobytes()
        header_bytes = "\n".join(entry.header for entry in self.entries).encode()
        all_bytes = (
            version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes
        )
        return all_bytes

    @staticmethod
    def _decode_headers(header_bytes: bytes, depth: int) -> list[str]:
        """Split the serialized header block into one header per row."""
        headers = header_bytes.decode().split("\n")
        # When the split already yields one header per row, keep it verbatim so
        # that legitimately empty headers survive a to_bytes/from_bytes round trip.
        if len(headers) != depth:
            # Sometimes the separation is two newlines, which results in an empty header.
            headers = [header for header in headers if header]
        return headers

    @classmethod
    def from_bytes(cls, data: bytes) -> MSA:
        """Deserialize an MSA produced by `to_bytes`.

        Raises:
            ValueError: If the version byte is not 1.
        """
        version_bytes, seqlen_bytes, depth_bytes, data = (
            data[:1],
            data[1:5],
            data[5:9],
            data[9:],
        )
        version = int.from_bytes(version_bytes, "little")
        if version != 1:
            raise ValueError(f"Unsupported version: {version}")
        seqlen = int.from_bytes(seqlen_bytes, "little")
        depth = int.from_bytes(depth_bytes, "little")
        array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
        array = np.frombuffer(array_bytes, dtype="|S1")
        array = array.reshape(depth, seqlen)
        headers = cls._decode_headers(header_bytes, depth)
        entries = [
            FastaEntry(header, b"".join(row).decode())
            for header, row in zip(headers, array)
        ]
        return cls(entries)

    # TODO(jmaccarl): set remove_insertions to True by default here to match other utils
    @classmethod
    def from_sequences(
        cls, sequences: list[str], remove_insertions: bool = False
    ) -> MSA:
        """Build an MSA from bare sequences; every header is the empty string."""
        if remove_insertions:
            entries = [
                FastaEntry("", remove_insertions_from_sequence(seq))
                for seq in sequences
            ]
        else:
            entries = [FastaEntry("", seq) for seq in sequences]
        return cls(entries)

    def to_sequence_bytes(self) -> bytes:
        """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost."""
        seqlen_bytes = self.seqlen.to_bytes(4, "little")
        array_bytes = self.array.tobytes()
        all_bytes = seqlen_bytes + array_bytes
        return all_bytes

    @classmethod
    def from_sequence_bytes(cls, data: bytes) -> MSA:
        """Deserialize output of `to_sequence_bytes`; headers are set to ``""``."""
        seqlen_bytes, array_bytes = data[:4], data[4:]
        seqlen = int.from_bytes(seqlen_bytes, "little")
        array = np.frombuffer(array_bytes, dtype="|S1")
        array = array.reshape(-1, seqlen)
        entries = [FastaEntry("", b"".join(row).decode()) for row in array]
        return cls(entries)

    @property
    def depth(self) -> int:
        """Number of sequences (rows)."""
        return len(self.entries)

    @property
    def seqlen(self) -> int:
        """Length of the first sequence (number of columns)."""
        return len(self.entries[0].sequence)

    @cached_property
    def array(self) -> np.ndarray:
        """Sequences as a ``(depth, seqlen)`` array of single bytes (``|S1``)."""
        return np.array([list(seq) for seq in self.sequences], dtype="|S1")

    @property
    def query(self) -> str:
        """The first sequence of the alignment."""
        return self.entries[0].sequence

    def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA:
        """Subselect rows of the MSA."""
        entries = [self.entries[idx] for idx in indices]
        return dataclasses.replace(self, entries=entries)

    def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA:
        """Subselect columns of the MSA."""
        entries = [
            FastaEntry(header, "".join(seq[idx] for idx in indices))
            for header, seq in self.entries
        ]
        return dataclasses.replace(self, entries=entries)

    def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
        """Subselect columns using numpy-style indexing; always returns an MSA."""
        if isinstance(indices, int):
            # Keep the result an MSA with a single column.
            indices = [indices]

        entries = [
            FastaEntry(header, slice_any_object(seq, indices))
            for header, seq in self.entries
        ]
        return dataclasses.replace(self, entries=entries)

    def __len__(self):
        """Length along the sequence dimension (``seqlen``), not the depth."""
        return self.seqlen

    def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA:
        """Greedily select sequences that either maximize or minimize hamming distance.

        Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
        iteratively add sequences to the list with the maximum (minimum) average Hamming
        distance to the existing set of sequences.

        Args:
            num_seqs (int): Number of sequences to select.
            mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
                you're doing it to prove a point for a paper.

        Returns:
            MSA object w/ subselected sequences.
        """
        assert mode in ("max", "min")
        if self.depth <= num_seqs:
            return self

        indices = greedy_select_indices(self.array, num_seqs, mode)
        return self.select_sequences(indices)

    def hhfilter(
        self,
        seqid: int = 90,
        diff: int = 0,
        cov: int = 0,
        qid: int = 0,
        qsc: float = -20.0,
        binary: str = "hhfilter",
    ) -> MSA:
        """Apply hhfilter to the sequences in the MSA and return a filtered MSA."""

        # Inside a method body the name `hhfilter` resolves to the module-level
        # function, not to this method.
        indices = hhfilter(
            self.sequences,
            seqid=seqid,
            diff=diff,
            cov=cov,
            qid=qid,
            qsc=qsc,
            binary=binary,
        )
        return self.select_sequences(indices)

    def select_random_sequences(self, num_seqs: int) -> MSA:
        """Uses random sampling to subselect sequences from the MSA. Always
        keeps the query sequence.
        """
        if num_seqs >= self.depth:
            return self

        # Subselect random, always keeping the query sequence.
        indices = np.sort(
            np.append(
                0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
            )
        )
        msa = self.select_sequences(indices)  # type: ignore
        return msa

    def select_diverse_sequences(self, num_seqs: int) -> MSA:
        """Applies hhfilter to select ~num_seqs sequences, then uses random sampling
        to subselect if necessary.
        """
        if num_seqs >= self.depth:
            return self

        msa = self.hhfilter(diff=num_seqs)
        if num_seqs < msa.depth:
            msa = msa.select_random_sequences(num_seqs)
        return msa

    def pad_to_depth(self, depth: int) -> MSA:
        """Append all-gap rows (``-``) with empty headers until ``depth`` rows exist.

        Raises:
            ValueError: If ``depth`` is smaller than the current depth.
        """
        if depth < self.depth:
            raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
        elif depth == self.depth:
            return self

        num_to_add = depth - self.depth
        extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)]
        return dataclasses.replace(self, entries=self.entries + extra_entries)

    @classmethod
    def stack(
        cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True
    ) -> MSA:
        """Stack a series of MSAs. Optionally remove the query from msas after the first."""
        all_entries = []
        for i, msa in enumerate(msas):
            entries = msa.entries
            if i > 0 and remove_query_from_later_msas:
                entries = entries[1:]
            all_entries.extend(entries)
        return cls(entries=all_entries)

    @cached_property
    def seqid(self) -> np.ndarray:
        """Per-row sequence identity to the query (1 - Hamming distance)."""
        array = self.array.view(np.uint8)
        seqid = 1 - cdist(array[0][None], array, "hamming")
        return seqid[0]

    @classmethod
    def concat(
        cls,
        msas: Sequence[MSA],
        join_token: str | None = "|",
        allow_depth_mismatch: bool = False,
    ) -> MSA:
        """Concatenate a series of MSAs horizontally, along the sequence dimension.

        Args:
            msas: MSAs to concatenate row by row.
            join_token: String inserted between the joined sequences of each
                row; None means no token. Headers are always joined with "|".
            allow_depth_mismatch: If True, shallower MSAs are padded with gap
                rows to the maximum depth instead of raising.

        Raises:
            ValueError: If ``msas`` is empty, or depths differ and
                ``allow_depth_mismatch`` is False.
        """
        if not msas:
            raise ValueError("Cannot concatenate an empty list of MSAs")
        msa_depths = [msa.depth for msa in msas]
        if len(set(msa_depths)) != 1:
            if not allow_depth_mismatch:
                raise ValueError("Depth mismatch in concatenating MSAs")
            else:
                max_depth = max(msa_depths)
                msas = [msa.pad_to_depth(max_depth) for msa in msas]
        headers = [
            "|".join([str(h) for h in headers])
            for headers in zip(*(msa.headers for msa in msas))
        ]

        if join_token is None:
            join_token = ""

        seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))]
        entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)]
        return cls(entries)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FastMSA(SequentialDataclass):
    """Object-oriented interface to an MSA stored as a 2-D numpy array.

    ``array`` has shape ``(depth, seqlen)``. The deserializers in this class
    produce single-byte character arrays (dtype ``|S1``); `pad_to_depth` also
    handles ``uint8`` arrays. Indexing and ``len`` operate along the sequence
    (column) dimension.

    Args:
        array (np.ndarray): The alignment, one row per sequence.
        headers (list[str] | None): Optional header per row; when given, its
            length must equal the depth.
    """

    array: np.ndarray
    headers: list[str] | None = None

    def __post_init__(self):
        if self.headers is not None:
            assert (
                len(self.headers) == self.depth
            ), "Number of headers must match depth."

    @staticmethod
    def _decode_headers(header_bytes: bytes, depth: int) -> list[str]:
        """Split the serialized header block into one header per row."""
        headers = header_bytes.decode().split("\n")
        # When the split already yields one header per row, keep it verbatim so
        # that legitimately empty headers do not trip the length check.
        if len(headers) != depth:
            # Sometimes the separation is two newlines, which results in an empty header.
            headers = [header for header in headers if header]
        return headers

    @classmethod
    def from_bytes(cls, data: bytes) -> FastMSA:
        """Deserialize from the versioned layout.

        Layout: 1 byte version, 4 bytes seqlen (little endian), 4 bytes depth
        (little endian), ``depth * seqlen`` bytes of characters, then the
        headers joined with newlines.

        Raises:
            ValueError: If the version byte is not 1.
        """
        version_bytes, seqlen_bytes, depth_bytes, data = (
            data[:1],
            data[1:5],
            data[5:9],
            data[9:],
        )
        version = int.from_bytes(version_bytes, "little")
        if version != 1:
            raise ValueError(f"Unsupported version: {version}")
        seqlen = int.from_bytes(seqlen_bytes, "little")
        depth = int.from_bytes(depth_bytes, "little")
        array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
        array = np.frombuffer(array_bytes, dtype="|S1")
        array = array.reshape(depth, seqlen)
        headers = cls._decode_headers(header_bytes, depth)
        return cls(array, headers)

    @classmethod
    def from_sequence_bytes(cls, data: bytes) -> FastMSA:
        """Deserialize the header-less layout: 4 bytes seqlen, then the characters."""
        seqlen_bytes, array_bytes = data[:4], data[4:]
        seqlen = int.from_bytes(seqlen_bytes, "little")
        array = np.frombuffer(array_bytes, dtype="|S1")
        array = array.reshape(-1, seqlen)
        return cls(array)

    @property
    def depth(self) -> int:
        """Number of sequences (rows)."""
        return self.array.shape[0]

    @property
    def seqlen(self) -> int:
        """Number of alignment columns."""
        return self.array.shape[1]

    def __len__(self):
        """Length along the sequence dimension (``seqlen``), not the depth."""
        return self.seqlen

    def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
        """Subselect columns using numpy indexing; always returns a FastMSA."""
        if isinstance(indices, int):
            # Keep the array two-dimensional for a single column.
            indices = [indices]

        return dataclasses.replace(self, array=self.array[:, indices])

    def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA:
        """Subselect rows of the MSA."""
        array = self.array[indices]
        headers = (
            [self.headers[idx] for idx in indices] if self.headers is not None else None
        )
        return dataclasses.replace(self, array=array, headers=headers)

    def select_random_sequences(self, num_seqs: int) -> FastMSA:
        """Uses random sampling to subselect sequences from the MSA. Always
        keeps the query sequence.
        """
        if num_seqs >= self.depth:
            return self

        # Subselect random, always keeping the query sequence.
        indices = np.sort(
            np.append(
                0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
            )
        )
        msa = self.select_sequences(indices)  # type: ignore
        return msa

    def pad_to_depth(self, depth: int) -> FastMSA:
        """Append all-gap rows (``-``) with empty headers until ``depth`` rows exist.

        Raises:
            ValueError: If ``depth`` is smaller than the current depth.
        """
        if depth < self.depth:
            raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
        elif depth == self.depth:
            return self

        num_to_add = depth - self.depth
        array = np.pad(
            self.array,
            [(0, num_to_add), (0, 0)],
            # The gap fill value must match the array's element type.
            constant_values=ord("-") if self.array.dtype == np.uint8 else b"-",
        )
        headers = self.headers
        if headers is not None:
            headers = headers + [""] * num_to_add
        return dataclasses.replace(self, array=array, headers=headers)

    @classmethod
    def concat(
        cls,
        msas: Sequence[FastMSA],
        join_token: str | None = None,
        allow_depth_mismatch: bool = False,
    ) -> FastMSA:
        """Concatenate a series of MSAs horizontally, along the sequence dimension.

        Args:
            msas: MSAs to concatenate row by row.
            join_token: Must be None or "". Separator tokens are not supported.
            allow_depth_mismatch: If True, shallower MSAs are padded with gap
                rows to the maximum depth instead of raising.

        Raises:
            ValueError: If ``msas`` is empty, or depths differ and
                ``allow_depth_mismatch`` is False.
            NotImplementedError: If a non-empty ``join_token`` is given.
        """
        if not msas:
            raise ValueError("Cannot concatenate an empty list of MSAs")
        if join_token is not None and join_token != "":
            raise NotImplementedError("join_token is not supported for FastMSA")

        msa_depths = [msa.depth for msa in msas]
        if len(set(msa_depths)) != 1:
            if not allow_depth_mismatch:
                raise ValueError("Depth mismatch in concatenating MSAs")
            else:
                max_depth = max(msa_depths)
                msas = [msa.pad_to_depth(max_depth) for msa in msas]
        # MSAs without headers contribute empty strings to the joined header.
        headers = [
            "|".join([str(h) for h in headers])
            for headers in zip(
                *(
                    msa.headers if msa.headers is not None else [""] * msa.depth
                    for msa in msas
                )
            )
        ]

        array = np.concatenate([msa.array for msa in msas], axis=1)
        return cls(array, headers)

    def to_msa(self) -> MSA:
        """Convert to an entry-based `MSA`; missing headers become ``seq{i}``.

        NOTE(review): rows are joined with ``b"".join``, so this appears to
        expect a ``|S1`` array rather than ``uint8`` - confirm.
        """
        headers = (
            self.headers
            if self.headers is not None
            else [f"seq{i}" for i in range(self.depth)]
        )
        entries = [
            FastaEntry(header, b"".join(row).decode())
            for header, row in zip(headers, self.array)
        ]
        return MSA(entries)

    @classmethod
    def stack(
        cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True
    ) -> FastMSA:
        """Stack a series of MSAs. Optionally remove the query from msas after the first.

        If none of the inputs carries headers, the result has ``headers=None``.
        If only some do, the rows of header-less inputs get empty-string headers
        so that the header count always matches the stacked depth.
        """
        arrays = []
        header_chunks: list[list[str] | None] = []
        for i, msa in enumerate(msas):
            array = msa.array
            headers = msa.headers
            if i > 0 and remove_query_from_later_msas:
                array = array[1:]
                if headers is not None:
                    headers = headers[1:]
            arrays.append(array)
            header_chunks.append(headers)

        if all(chunk is None for chunk in header_chunks):
            # Passing [] here would fail the header/depth check in __post_init__.
            all_headers = None
        else:
            all_headers = []
            for array, chunk in zip(arrays, header_chunks):
                all_headers.extend(
                    chunk if chunk is not None else [""] * array.shape[0]
                )
        return cls(np.concatenate(arrays, axis=0), all_headers)
|
||||
83
esm/utils/parsing.py
Normal file
83
esm/utils/parsing.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Generator, Iterable, NamedTuple
|
||||
|
||||
PathOrBuffer = str | Path | io.TextIOBase
|
||||
FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)])
|
||||
|
||||
|
||||
def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]:
    """
    Lazily parse FASTA text into FastaEntry objects.

    Empty lines and lines starting with "#" are ignored. A line starting with
    ">" opens a new record; its header is the rest of the line, stripped. All
    other lines are concatenated, unmodified, into the record's sequence.

    Args:
        fasta_string: The fasta file as a string

    Returns:
        A generator of FastaEntry objects

    Raises:
        ValueError: While iterating, if the input contains no header line.
    """
    current_header = None
    chunks = []
    for raw_line in fasta_string.splitlines():
        # Blank lines and comment lines carry no record data.
        if raw_line == "" or raw_line.startswith("#"):
            continue
        if not raw_line.startswith(">"):
            chunks.append(raw_line)
            continue
        # A new header closes the record that was being accumulated.
        if current_header is not None:
            yield FastaEntry(current_header, "".join(chunks))
            chunks = []
        current_header = raw_line[1:].strip()

    if current_header is None:
        raise ValueError("Found no sequences in input")
    yield FastaEntry(current_header, "".join(chunks))
|
||||
|
||||
|
||||
def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]:
    """Yield every FASTA record found at ``path``.

    Args:
        path: Either something ``open`` accepts (names ending in ``.gz`` are
            opened with ``gzip`` in text mode) or an already-open text buffer
            exposing ``read()``.

    Yields:
        The ``FastaEntry`` objects produced by ``parse_fasta``.

    Raises:
        ValueError: Propagated from ``parse_fasta`` when no record is found.
    """
    # Uses duck typing instead of explicit isinstance checks so that inputs
    # which are not literally str/Path/TextIOBase but behave like them still
    # work: whatever ``open`` rejects with TypeError is treated as a buffer.
    owns_handle = True
    if str(path).endswith(".gz"):
        import gzip

        handle = gzip.open(path, "rt")  # type: ignore
    else:
        try:
            handle = open(path)  # type: ignore
        except TypeError:
            handle = path  # type: ignore
            # The caller opened this buffer, so the caller closes it
            # (same ownership rule as write_sequences).
            owns_handle = False

    # Read everything up front so the handle is released before the first
    # record is yielded, rather than whenever the generator is finalized.
    try:
        text = handle.read()  # type: ignore
    finally:
        if owns_handle:
            handle.close()  # type: ignore

    yield from parse_fasta(text)
|
||||
|
||||
|
||||
def read_first_sequence(path: PathOrBuffer) -> FastaEntry:
    """Return the first FASTA record found at ``path``.

    Args:
        path: A path or text buffer, as accepted by ``read_sequences``.

    Raises:
        ValueError: Propagated from parsing when the input holds no records.
    """
    records = read_sequences(path)
    try:
        return next(records)
    finally:
        # Finalize the suspended generator now so any cleanup pending inside
        # it runs deterministically instead of waiting for garbage collection.
        records.close()
|
||||
|
||||
|
||||
def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None:
    """Write ``(header, sequence)`` pairs to ``path`` in FASTA format.

    Records are separated by a single newline and no trailing newline is
    written. ``path`` may be anything ``open`` accepts, or an object that
    ``open`` rejects with TypeError, which is then used directly as a
    writable buffer and left open for the caller.
    """
    try:
        out = open(path, "w")  # type: ignore
    except TypeError:
        out = path
        opened_here = False
    else:
        opened_here = True

    try:
        for index, (name, residues) in enumerate(sequences):
            if index > 0:
                out.write("\n")  # type: ignore
            out.write(f">{name}\n{residues}")  # type: ignore
    finally:
        # Only close what this function opened.
        if opened_here:
            out.close()  # type: ignore
|
||||
@@ -5,9 +5,18 @@ import attr
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
|
||||
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.sdk.api import (
|
||||
ESMProteinTensor,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_invalid_tokenizer_ids,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.utils.constants.esm3 import (
|
||||
MAX_RESIDUE_ANNOTATIONS,
|
||||
SASA_DISCRETIZATION_BOUNDARIES,
|
||||
|
||||
157
esm/utils/sequential_dataclass.py
Normal file
157
esm/utils/sequential_dataclass.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from esm.utils.misc import concat_objects, slice_any_object
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SequentialDataclass(ABC):
|
||||
"""
|
||||
This is a builder on a dataclass that allows for automatic slicing and concatenation.
|
||||
|
||||
When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein).
|
||||
|
||||
When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function).
|
||||
|
||||
We also have some fields that are not sequential (like an id, or data source), which we don't want to crop.
|
||||
|
||||
The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically.
|
||||
|
||||
This is done through the `metadata` field, which can take 3 values:
|
||||
`sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False.
|
||||
`sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
|
||||
`join_token` (Any): What token to use to join when concatenating elements. Default: None.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Foo(SequentialDataclass):
|
||||
id: str
|
||||
sequence: str = field(metadata={"sequence": True, "join_token": "|"})
|
||||
tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
|
||||
|
||||
def __len__(self):
|
||||
# Must implement the __len__ method
|
||||
return len(self.sequence)
|
||||
|
||||
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
|
||||
Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717]))
|
||||
|
||||
>>> foo[1:4]
|
||||
Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251]))
|
||||
|
||||
>>> foo[np.arange(5) < 3]
|
||||
Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
|
||||
|
||||
>>> Foo.concat([foo[:2], foo[3:]])
|
||||
Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717]))
|
||||
|
||||
# Trying to create a type where the sequence lengths do not match raises an error
|
||||
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
|
||||
ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
|
||||
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
self._check_sequence_lengths_match()
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
|
||||
updated_fields = {}
|
||||
if isinstance(idx, int):
|
||||
# make it so that things remain sequential
|
||||
idx = [idx]
|
||||
|
||||
for fld in fields(self):
|
||||
if fld.metadata.get("sequence", False):
|
||||
# this is a sequence, should be the same length as all other sequences
|
||||
sequence_dim = fld.metadata.get("sequence_dim", 0)
|
||||
value = getattr(self, fld.name)
|
||||
if value is None:
|
||||
continue
|
||||
match sequence_dim:
|
||||
case 0:
|
||||
# sequence is first dimension
|
||||
value = getattr(self, fld.name)
|
||||
value = slice_any_object(value, idx)
|
||||
updated_fields[fld.name] = value
|
||||
case 1:
|
||||
new_value = [slice_any_object(item, idx) for item in value]
|
||||
updated_fields[fld.name] = value.__class__(new_value)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
"Arbitrary slicing for different sequence length fields is not implemented"
|
||||
)
|
||||
|
||||
return replace(self, **updated_fields)
|
||||
|
||||
def _check_sequence_lengths_match(self):
|
||||
"""Checks if sequence lengths of all "sequence" fields match."""
|
||||
for fld in fields(self):
|
||||
if fld.metadata.get("sequence", False) and fld.name != "complex":
|
||||
# this is a sequence, should be the same length as all other sequences
|
||||
sequence_dim = fld.metadata.get("sequence_dim", 0)
|
||||
value = getattr(self, fld.name)
|
||||
if value is None:
|
||||
continue
|
||||
match sequence_dim:
|
||||
case 0:
|
||||
# sequence is first dimension
|
||||
value = getattr(self, fld.name)
|
||||
if len(value) != len(self):
|
||||
raise ValueError(
|
||||
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}"
|
||||
)
|
||||
case 1:
|
||||
for item in value:
|
||||
if len(item) != len(self):
|
||||
raise ValueError(
|
||||
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
|
||||
)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
"Arbitrary matching for different sequence length fields is not implemented"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def concat(cls, items: list[T], **kwargs) -> T:
|
||||
updated_fields = {}
|
||||
for fld in fields(cls):
|
||||
if fld.metadata.get("sequence", False):
|
||||
# this is a sequence, should be the same length as all other sequences
|
||||
sequence_dim = fld.metadata.get("sequence_dim", 0)
|
||||
join_value = fld.metadata.get("join_token", None)
|
||||
if getattr(items[0], fld.name) is None:
|
||||
continue
|
||||
values = [getattr(item, fld.name) for item in items]
|
||||
match sequence_dim:
|
||||
case 0:
|
||||
# sequence is first dimension
|
||||
value = concat_objects(values, join_value)
|
||||
updated_fields[fld.name] = value
|
||||
case 1:
|
||||
new_value = [
|
||||
concat_objects(item, join_value) for item in zip(*values)
|
||||
]
|
||||
updated_fields[fld.name] = getattr(
|
||||
items[0], fld.name
|
||||
).__class__(new_value)
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
"Arbitrary joining for different sequence length fields is not implemented"
|
||||
)
|
||||
updated_fields.update(kwargs)
|
||||
|
||||
return replace(
|
||||
items[0], # type: ignore
|
||||
**updated_fields,
|
||||
)
|
||||
@@ -6,7 +6,9 @@ from typing import Any, ClassVar, Protocol, TypeVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
|
||||
from esm.utils.structure.protein_structure import (
|
||||
compute_affine_and_rmsd,
|
||||
)
|
||||
|
||||
|
||||
class Alignable(Protocol):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||
from esm.utils.structure.protein_structure import (
|
||||
index_by_atom_name,
|
||||
)
|
||||
|
||||
|
||||
class AtomIndexer:
|
||||
|
||||
96
esm/utils/structure/input_builder.py
Normal file
96
esm/utils/structure/input_builder.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
@dataclass
class Modification:
    """A modification attached to one position of a chain's sequence."""

    # Zero-indexed position within the sequence.
    position: int
    # presumably a Chemical Component Dictionary code — TODO confirm
    ccd: str
|
||||
|
||||
|
||||
@dataclass
class ProteinInput:
    """A protein chain entry for a structure-prediction request."""

    # One identifier, or a list of identifiers.
    id: str | list[str]
    sequence: str
    # Optional per-position modifications; None when there are none.
    modifications: list[Modification] | None = None
|
||||
|
||||
|
||||
@dataclass
class RNAInput:
    """An RNA chain entry for a structure-prediction request."""

    # One identifier, or a list of identifiers.
    id: str | list[str]
    sequence: str
    # Optional per-position modifications; None when there are none.
    modifications: list[Modification] | None = None
|
||||
|
||||
|
||||
@dataclass
class DNAInput:
    """A DNA chain entry for a structure-prediction request."""

    # One identifier, or a list of identifiers.
    id: str | list[str]
    sequence: str
    # Optional per-position modifications; None when there are none.
    modifications: list[Modification] | None = None
|
||||
|
||||
|
||||
@dataclass
class LigandInput:
    """A ligand entry for a structure-prediction request."""

    # One identifier, or a list of identifiers.
    id: str | list[str]
    # SMILES string of the ligand.
    smiles: str
    # presumably Chemical Component Dictionary codes for the ligand — TODO confirm
    ccd: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
class DistogramConditioning:
    """A distogram array associated with one chain identifier."""

    chain_id: str
    # NOTE(review): expected array shape is not visible here — confirm with the consumer
    distogram: np.ndarray
|
||||
|
||||
|
||||
@dataclass
class PocketConditioning:
    """Pocket conditioning: a binder chain plus a list of contacts."""

    binder_chain_id: str
    # presumably (chain id, residue position) pairs — TODO confirm indexing convention
    contacts: list[tuple[str, int]]
|
||||
|
||||
|
||||
@dataclass
class StructurePredictionInput:
    """Everything passed to a structure-prediction request."""

    # The chains and ligands to model, in order.
    sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
    # Optional pocket conditioning.
    pocket: PocketConditioning | None = None
    # Optional per-chain distogram conditioning.
    distogram_conditioning: list[DistogramConditioning] | None = None
|
||||
|
||||
|
||||
def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
    """Turn a ``StructurePredictionInput`` into a plain ``{"sequences": [...]}`` dict.

    Each entry of ``all_atom_input.sequences`` becomes one dict. Protein, RNA
    and DNA inputs produce ``sequence``/``id``/``type`` keys, plus
    ``modifications`` when any are set; ligand inputs produce
    ``smiles``/``id``/``ccd``/``type``. Only the ``sequences`` field is read
    here; ``pocket`` and ``distogram_conditioning`` are not part of the result.

    Raises:
        ValueError: If an entry is not one of the supported input classes.
    """
    polymer_labels = (
        (ProteinInput, "protein"),
        (RNAInput, "rna"),
        (DNAInput, "dna"),
    )

    def polymer_entry(chain, label: str) -> dict[str, Any]:
        # Build the dict shared by protein / RNA / DNA chains.
        entry: dict[str, Any] = {
            "sequence": chain.sequence,
            "id": chain.id,
            "type": label,
        }
        mods = getattr(chain, "modifications", None)
        if mods:
            entry["modifications"] = [
                {"position": mod.position, "ccd": mod.ccd} for mod in mods
            ]
        return entry

    def serialize_one(item) -> dict[str, Any]:
        # Dispatch on the input class; the first matching class wins.
        for input_cls, label in polymer_labels:
            if isinstance(item, input_cls):
                return polymer_entry(item, label)
        if isinstance(item, LigandInput):
            return {
                "smiles": item.smiles,
                "id": item.id,
                "ccd": item.ccd,
                "type": "ligand",
            }
        raise ValueError(f"Unsupported sequence input type: {type(item)}")

    return {"sequences": [serialize_one(item) for item in all_atom_input.sequences]}
|
||||
@@ -264,7 +264,7 @@ def compute_lddt_ca(
|
||||
if all_atom_pred_pos.dim() != 3:
|
||||
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
|
||||
all_atom_positions = all_atom_positions[..., ca_pos, :]
|
||||
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
|
||||
all_atom_mask = all_atom_mask[..., ca_pos]
|
||||
|
||||
return compute_lddt(
|
||||
all_atom_pred_pos,
|
||||
|
||||
944
esm/utils/structure/molecular_complex.py
Normal file
944
esm/utils/structure/molecular_complex.py
Normal file
@@ -0,0 +1,944 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from subprocess import check_output
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
import biotite.structure.io.pdbx as pdbx
|
||||
import brotli
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from esm.utils import residue_constants
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_lddt,
|
||||
compute_rmsd,
|
||||
)
|
||||
from esm.utils.structure.protein_complex import (
|
||||
ProteinComplex,
|
||||
ProteinComplexMetadata,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MolecularComplexResult:
    """Result of molecular complex folding.

    Only ``complex`` is required; every other field defaults to None.
    """

    # The folded complex itself.
    complex: MolecularComplex

    # Optional confidence outputs (semantics are defined by the producer).
    plddt: torch.Tensor | None = None
    ptm: float | None = None
    iptm: float | None = None
    pae: torch.Tensor | None = None
    distogram: torch.Tensor | None = None
    pair_chains_iptm: torch.Tensor | None = None

    # Optional embedding outputs.
    output_embedding_sequence: torch.Tensor | None = None
    output_embedding_pair_pooled: torch.Tensor | None = None
|
||||
|
||||
|
||||
@dataclass
class MolecularComplexMetadata:
    """Metadata for MolecularComplex objects."""

    # Integer key -> entity identifier string.
    entity_lookup: dict[int, str]
    # Integer key -> chain identifier string.
    chain_lookup: dict[int, str]
    # Optional mapping from a string key to a list of strings.
    assembly_composition: dict[str, list[str]] | None = None
|
||||
|
||||
|
||||
@dataclass
class Molecule:
    """Represents a single molecule/token within a MolecularComplex."""

    # Token name and its index in the parent sequence.
    token: str
    token_idx: int

    # Atoms belonging to this token.
    atom_positions: np.ndarray  # [N_atoms, 3]
    atom_elements: np.ndarray  # [N_atoms] element strings

    residue_type: int
    molecule_type: int  # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
    confidence: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MolecularComplex:
|
||||
"""
|
||||
Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands.
|
||||
|
||||
Uses a flat atom representation with token-based sequence indexing, supporting all atom types
|
||||
beyond the traditional atom37 protein representation.
|
||||
"""
|
||||
|
||||
id: str
|
||||
sequence: List[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP']
|
||||
|
||||
# Flat atom arrays - simplified representation
|
||||
atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates
|
||||
atom_elements: np.ndarray # [N_atoms] element strings
|
||||
|
||||
# Token-to-atom mapping for efficient access
|
||||
token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
|
||||
|
||||
# Confidence data
|
||||
plddt: np.ndarray # Per-token confidence scores [N_tokens]
|
||||
|
||||
# Metadata
|
||||
metadata: MolecularComplexMetadata
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate array dimensions."""
|
||||
n_tokens = len(self.sequence)
|
||||
assert (
|
||||
self.token_to_atoms.shape[0] == n_tokens
|
||||
), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
|
||||
assert (
|
||||
self.plddt.shape[0] == n_tokens
|
||||
), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return number of tokens."""
|
||||
return len(self.sequence)
|
||||
|
||||
def __getitem__(self, idx: int) -> Molecule:
|
||||
"""Access individual molecules/tokens by index."""
|
||||
if idx >= len(self.sequence) or idx < 0:
|
||||
raise IndexError(
|
||||
f"Token index {idx} out of range for {len(self.sequence)} tokens"
|
||||
)
|
||||
|
||||
token = self.sequence[idx]
|
||||
start_atom, end_atom = self.token_to_atoms[idx]
|
||||
|
||||
# Extract atom data for this token
|
||||
token_atom_positions = self.atom_positions[start_atom:end_atom]
|
||||
token_atom_elements = self.atom_elements[start_atom:end_atom]
|
||||
|
||||
# Default values for residue/molecule type (would be extended based on actual implementation)
|
||||
residue_type = 0 # Default to standard residue
|
||||
molecule_type = 0 # Default to protein
|
||||
|
||||
return Molecule(
|
||||
token=token,
|
||||
token_idx=idx,
|
||||
atom_positions=token_atom_positions,
|
||||
atom_elements=token_atom_elements,
|
||||
residue_type=residue_type,
|
||||
molecule_type=molecule_type,
|
||||
confidence=self.plddt[idx],
|
||||
)
|
||||
|
||||
@property
|
||||
def atom_coordinates(self) -> np.ndarray:
|
||||
"""Get flat array of all atom coordinates [N_atoms, 3]."""
|
||||
return self.atom_positions
|
||||
|
||||
# Conversion methods
|
||||
@classmethod
def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex":
    """Convert a ProteinComplex to MolecularComplex.

    Every non-"|" character of ``pc.sequence`` becomes one token (its
    three-letter residue name, or "UNK" when the letter is unknown). The
    atoms flagged in the residue's atom37 mask are appended to the flat atom
    arrays in ``residue_constants.atom_types`` order. The element of each
    atom is taken from the first character of its atom name.

    Args:
        pc: ProteinComplex object with atom37 representation

    Returns:
        MolecularComplex with flat atom arrays and token-based indexing
    """
    atom_types = residue_constants.atom_types
    restype_1to3 = residue_constants.restype_1to3

    sequence_tokens = []
    flat_positions = []
    flat_elements = []
    token_to_atoms = []
    confidence_scores = []

    # NOTE(review): residue_idx advances only on non-break characters, so the
    # per-residue arrays are read as if they held no row for "|". If
    # ProteinComplex stores a placeholder row per chain break, residues after
    # the first break would be read one row early — confirm against
    # ProteinComplex before relying on multi-chain conversion.
    residue_idx = 0
    for aa in pc.sequence:
        if aa == "|":
            # Skip chain break tokens
            continue

        sequence_tokens.append(restype_1to3.get(aa, "UNK"))

        res_positions = pc.atom37_positions[residue_idx]  # [37, 3]
        res_mask = pc.atom37_mask[residue_idx]  # [37]

        token_start = len(flat_positions)
        for atom_type_idx, atom_name in enumerate(atom_types):
            if res_mask[atom_type_idx]:  # Atom is present
                flat_positions.append(res_positions[atom_type_idx])
                # First character of the atom name is used as the element.
                flat_elements.append(atom_name[0] if atom_name else "C")

        # Record token-to-atom mapping [start_idx, end_idx)
        token_to_atoms.append([token_start, len(flat_positions)])
        confidence_scores.append(pc.confidence[residue_idx])
        residue_idx += 1

    # reshape keeps the documented [N, 3] / [N, 2] layouts even when no atom
    # or no token was collected (np.array([]) alone would be 1-D).
    atom_positions = np.array(flat_positions, dtype=np.float32).reshape(-1, 3)
    atom_elements = np.array(flat_elements, dtype=object)
    token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32).reshape(-1, 2)
    confidence_array = np.array(confidence_scores, dtype=np.float32)

    # Create metadata - convert entity IDs to strings for MolecularComplexMetadata
    entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
    metadata = MolecularComplexMetadata(
        entity_lookup=entity_lookup_str,
        chain_lookup=pc.metadata.chain_lookup,
        assembly_composition=pc.metadata.assembly_composition,
    )

    return cls(
        id=pc.id,
        sequence=sequence_tokens,
        atom_positions=atom_positions,
        atom_elements=atom_elements,
        token_to_atoms=token_to_atoms_array,
        plddt=confidence_array,
        metadata=metadata,
    )
|
||||
|
||||
def to_protein_complex(self) -> ProteinComplex:
    """Convert MolecularComplex back to ProteinComplex format.

    Only tokens found in ``residue_constants.restype_3to1`` are kept. Their
    flat atoms are written back into atom37 arrays, and the result is
    emitted as a single entity and chain ("A") with residue indices
    1..n_residues.

    The atom37 reconstruction fills, in ``atom_types`` order, the slots
    listed for the residue in ``residue_constants.residue_atoms`` with the
    token's atoms in stored order, stopping when either runs out. It
    therefore assumes each token's atoms are stored in that same order
    with none missing.

    Returns:
        ProteinComplex with protein residues only, excluding ligands/nucleic acids

    Raises:
        ValueError: If the complex contains no protein token.
    """
    restype_3to1 = residue_constants.restype_3to1

    # Filter for protein tokens only (skip ligands, nucleic acids)
    protein_indices = [
        position
        for position, token in enumerate(self.sequence)
        if token in restype_3to1
    ]
    if not protein_indices:
        raise ValueError("No protein tokens found in MolecularComplex")

    protein_tokens = [self.sequence[position] for position in protein_indices]
    n_residues = len(protein_tokens)

    # Initialize atom37 arrays
    atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
    atom37_mask = np.zeros((n_residues, 37), dtype=bool)

    single_letter_sequence = "".join(restype_3to1[token] for token in protein_tokens)

    # Extract confidence scores for protein residues only
    protein_confidence = self.plddt[protein_indices]

    for res_idx, token_idx in enumerate(protein_indices):
        start_atom, end_atom = self.token_to_atoms[token_idx]
        res_atom_positions = self.atom_positions[start_atom:end_atom]

        # atom37 slots this residue type is expected to occupy, in atom_types order
        expected_atoms = residue_constants.residue_atoms.get(
            self.sequence[token_idx], []
        )
        expected_slots = [
            slot
            for slot, atom_name in enumerate(residue_constants.atom_types)
            if atom_name in expected_atoms
        ]

        # zip stops at the shorter side, so surplus slots stay masked out
        # and surplus atoms are ignored.
        for slot, position in zip(expected_slots, res_atom_positions):
            atom37_positions[res_idx, slot] = position
            atom37_mask[res_idx, slot] = True

    # For simplicity, all protein residues are placed in one entity/chain
    zeros = lambda: np.zeros(n_residues, dtype=np.int64)  # noqa: E731
    entity_id = zeros()
    chain_id = zeros()
    sym_id = zeros()
    residue_index = np.arange(1, n_residues + 1, dtype=np.int64)
    insertion_code = np.array([""] * n_residues, dtype=object)

    # Simplified metadata: a single entity and a single chain
    protein_metadata = ProteinComplexMetadata(
        entity_lookup={0: 1},  # Single entity (int for ProteinComplexMetadata)
        chain_lookup={0: "A"},  # Single chain
        assembly_composition=self.metadata.assembly_composition,
    )

    return ProteinComplex(
        id=self.id,
        sequence=single_letter_sequence,
        entity_id=entity_id,
        chain_id=chain_id,
        sym_id=sym_id,
        residue_index=residue_index,
        insertion_code=insertion_code,
        atom37_positions=atom37_positions,
        atom37_mask=atom37_mask,
        confidence=protein_confidence,
        metadata=protein_metadata,
    )
|
||||
|
||||
@classmethod
def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex":
    """Read MolecularComplex from mmcif file or string.

    Atoms are grouped by (chain_id, res_id); chains and residues are then
    emitted in sorted order, one token per residue named after the residue.
    Residues named "HOH" are skipped.

    Args:
        inp: Path to mmCIF file or mmCIF content as string
        id: Optional identifier to assign to the complex. Defaults to the
            file stem for a path, or "complex_from_string" for string input.

    Returns:
        MolecularComplex with all molecules (proteins, ligands, nucleic acids)
    """
    # An existing filesystem path is read as a file; anything else is
    # treated as mmCIF text.
    inp_is_path = os.path.exists(inp)
    mmcif_file = pdbx.CIFFile.read(inp if inp_is_path else io.StringIO(inp))

    # Get structure - handle missing model information gracefully
    try:
        structure = pdbx.get_structure(mmcif_file, model=1)
    except (KeyError, ValueError):
        # Fallback for mmCIF files without model information
        try:
            structure = pdbx.get_structure(mmcif_file)
        except Exception:
            # Last resort: use the first available model or all atoms
            structure = pdbx.get_structure(mmcif_file, model=None)
    # Type hint for pyright - structure is an AtomArray which is iterable
    if TYPE_CHECKING:
        structure: Any = structure

    # Entity id -> entity type, read best-effort from the first data block;
    # any failure here leaves the mapping empty or partial.
    entity_info = {}
    try:
        block = mmcif_file[0]
        if "entity" in block:
            entity_category = block["entity"]
            if "id" in entity_category and "type" in entity_category:
                entity_ids = entity_category["id"]
                entity_types = entity_category["type"]
                if hasattr(entity_ids, "__iter__") and hasattr(
                    entity_types, "__iter__"
                ):
                    entity_info.update(
                        zip(
                            list(entity_ids),  # type: ignore
                            list(entity_types),  # type: ignore
                        )
                    )
    except Exception:
        pass

    # Group atoms by chain and residue; the residue name comes from the
    # first atom seen for that residue.
    chain_residue_groups = {}
    for atom in structure:
        residues = chain_residue_groups.setdefault(atom.chain_id, {})
        record = residues.get(atom.res_id)
        if record is None:
            record = {"atoms": [], "res_name": atom.res_name}
            residues[atom.res_id] = record
        record["atoms"].append(atom)

    # Flat atom representation, filled chain by chain, residue by residue
    sequence_tokens = []
    flat_positions = []
    flat_elements = []
    token_to_atoms = []
    confidence_scores = []

    for chain_id in sorted(chain_residue_groups):
        residues = chain_residue_groups[chain_id]
        for res_id in sorted(residues):
            record = residues[res_id]
            res_name = record["res_name"]

            # Skip water molecules
            if res_name == "HOH":
                continue

            # Amino acids, nucleotides and ligands all use the residue
            # name as their token.
            sequence_tokens.append(res_name)

            atoms = record["atoms"]
            token_start = len(flat_positions)
            for atom in atoms:
                flat_positions.append(atom.coord)
                flat_elements.append(atom.element)

            # Record token-to-atom mapping [start, end)
            token_to_atoms.append([token_start, len(flat_positions)])

            # Confidence is the first atom's B-factor / 100, capped at 1.0.
            # When the atom exposes no ``b_factor`` the fallback of 50.0
            # yields 0.5.
            # NOTE(review): whether atoms carry ``b_factor`` depends on how
            # the structure was loaded — confirm it is populated here.
            bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0
            confidence_scores.append(min(bfactor / 100.0, 1.0))

    # Convert to numpy arrays
    if flat_positions:
        atom_positions = np.array(flat_positions, dtype=np.float32)
        atom_elements = np.array(flat_elements, dtype=object)
        token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
    else:
        # Create minimal arrays if no atoms found
        atom_positions = np.zeros((0, 3), dtype=np.float32)
        atom_elements = np.zeros(0, dtype=object)
        token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)

    confidence_array = np.array(confidence_scores, dtype=np.float32)

    # Chains are numbered in the order they were first encountered.
    metadata = MolecularComplexMetadata(
        entity_lookup=entity_info,
        chain_lookup=dict(enumerate(chain_residue_groups)),
        assembly_composition=None,
    )

    # Set complex ID - if input was a path, use the stem; otherwise use default
    complex_id = id or (Path(inp).stem if inp_is_path else "complex_from_string")

    return cls(
        id=complex_id,
        sequence=sequence_tokens,
        atom_positions=atom_positions,
        atom_elements=atom_elements,
        token_to_atoms=token_to_atoms_array,
        plddt=confidence_array,
        metadata=metadata,
    )
|
||||
|
||||
def to_mmcif(self) -> str:
    """Write MolecularComplex to mmcif string.

    Tokens whose name is a key of ``residue_constants.restype_3to1`` are
    written as ``ATOM`` records belonging to entity 1 / chain "A". Every
    other distinct token name gets its own non-polymer entity and chain,
    numbered in order of first appearance, and its atoms are written as
    ``HETATM`` records that reference that same entity and chain. The
    ``_entity``, ``_struct_asym`` and ``_atom_site`` tables are all derived
    from one assignment table, so they always agree with each other and the
    output does not depend on set iteration order.

    Returns:
        String representation of the complex in mmCIF format
    """
    # Classify every token once; the flags are reused by all three tables.
    protein_flags = [
        token in residue_constants.restype_3to1 for token in self.sequence
    ]
    has_protein = any(protein_flags)

    # dict.fromkeys de-duplicates while keeping first-seen order. Iterating a
    # set of strings here would make the row order hash-seed dependent.
    ligand_names = list(
        dict.fromkeys(
            token
            for token, is_protein in zip(self.sequence, protein_flags)
            if not is_protein
        )
    )

    # The protein (if any) owns entity 1 / chain "A"; ligands follow on.
    # Chain letters stop advancing at "Z", so ligand types beyond that share it.
    first_ligand_entity = 2 if has_protein else 1
    first_ligand_chain = 1 if has_protein else 0
    last_chain_offset = 25
    ligand_slots = {
        name: (
            first_ligand_entity + rank,
            chr(ord("A") + min(first_ligand_chain + rank, last_chain_offset)),
        )
        for rank, name in enumerate(ligand_names)
    }

    # Header and structure metadata
    lines = [
        f"data_{self.id}",
        "#",
        f"_entry.id {self.id}",
        "#",
        f"_struct.entry_id {self.id}",
        "_struct.title 'Protein Structure'",
        "#",
    ]

    # Entity information
    lines.extend(
        ["loop_", "_entity.id", "_entity.type", "_entity.pdbx_description"]
    )
    if has_protein:
        lines.append("1 polymer 'Protein chain'")
    for name, (ligand_entity_id, _) in ligand_slots.items():
        lines.append(f"{ligand_entity_id} non-polymer 'Ligand {name}'")
    lines.append("#")

    # Chain assignments
    lines.extend(["loop_", "_struct_asym.id", "_struct_asym.entity_id"])
    if has_protein:
        lines.append("A 1")
    for ligand_entity_id, ligand_chain_id in ligand_slots.values():
        lines.append(f"{ligand_chain_id} {ligand_entity_id}")
    lines.append("#")

    # Atom site information
    lines.extend(
        [
            "loop_",
            "_atom_site.group_PDB",
            "_atom_site.id",
            "_atom_site.type_symbol",
            "_atom_site.label_atom_id",
            "_atom_site.label_alt_id",
            "_atom_site.label_comp_id",
            "_atom_site.label_asym_id",
            "_atom_site.label_entity_id",
            "_atom_site.label_seq_id",
            "_atom_site.pdbx_PDB_ins_code",
            "_atom_site.Cartn_x",
            "_atom_site.Cartn_y",
            "_atom_site.Cartn_z",
            "_atom_site.occupancy",
            "_atom_site.B_iso_or_equiv",
            "_atom_site.pdbx_PDB_model_num",
            "_atom_site.auth_seq_id",
            "_atom_site.auth_comp_id",
            "_atom_site.auth_asym_id",
            "_atom_site.auth_atom_id",
        ]
    )

    atom_id = 1
    for token_idx, token in enumerate(self.sequence):
        start_atom, end_atom = self.token_to_atoms[token_idx]
        n_atoms = end_atom - start_atom
        seq_id = token_idx + 1  # sequence ids are 1-based, one per token

        if protein_flags[token_idx]:
            group_pdb = "ATOM"
            current_entity_id = 1
            current_chain_id = "A"
            # Use standard protein atom names, backbone-only as a fallback.
            # list(...) guarantees a private, appendable copy.
            atom_names = list(
                residue_constants.residue_atoms.get(token, ["N", "CA", "C", "O"])[
                    :n_atoms
                ]
            )
        else:
            group_pdb = "HETATM"
            current_entity_id, current_chain_id = ligand_slots[token]
            # Generic atom names for ligands
            atom_names = [f"C{i+1}" for i in range(n_atoms)]

        # Pad atom names if the token has more atoms than known names
        while len(atom_names) < n_atoms:
            atom_names.append(f"X{len(atom_names)+1}")

        # The stored per-token confidence is written back as a B-factor,
        # scaled by 100; 50.0 is used when no confidence exists for the token.
        bfactor = (
            self.plddt[token_idx] * 100.0 if len(self.plddt) > token_idx else 50.0
        )

        # Write atoms for this token
        for atom_name, global_atom_idx in zip(
            atom_names, range(start_atom, end_atom)
        ):
            pos = self.atom_positions[global_atom_idx]
            element_char = self.atom_elements[global_atom_idx]
            # Non-string element entries fall back to carbon.
            element_symbol = element_char if isinstance(element_char, str) else "C"

            line = (
                f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . "
                f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? "
                f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 "
                f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}"
            )
            lines.append(line)
            atom_id += 1

    lines.append("#")
    return "\n".join(lines)
|
||||
|
||||
def dockq(self, native: "MolecularComplex") -> Any:
    """Compute DockQ score against native structure.

    Both complexes are converted with ``to_protein_complex()`` and their
    chain IDs normalized with ``normalize_chain_ids_for_pdb()``; the score is
    then delegated to the converted object's own ``dockq()`` method. If that
    call raises, ``_compute_dockq_manual`` is used as a fallback.

    Args:
        native: Native MolecularComplex to compute DockQ against

    Returns:
        DockQ result containing score and alignment information

    Raises:
        ValueError: If either complex cannot be converted; the original
            conversion error is attached as ``__cause__``.
    """
    try:
        self_pc = self.to_protein_complex()
        native_pc = native.to_protein_complex()
    except ValueError as e:
        # Chain the cause so the traceback still shows which conversion failed.
        raise ValueError(
            f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
        ) from e

    # Normalize chain IDs for PDB compatibility
    self_pc = self_pc.normalize_chain_ids_for_pdb()
    native_pc = native_pc.normalize_chain_ids_for_pdb()

    try:
        return self_pc.dockq(native_pc)
    except Exception:
        # Deliberate best-effort: any failure of the primary path falls back
        # to the manual DockQ computation instead of propagating.
        return self._compute_dockq_manual(native)
|
||||
|
||||
def _compute_dockq_manual(self, native: "MolecularComplex") -> Any:
    """Manual DockQ computation fallback.

    Converts both complexes with ``to_protein_complex()``, writes them as PDB
    files into a temporary directory, runs the external ``DockQ`` executable
    on the two files and parses the score from its standard output.

    Args:
        native: Native MolecularComplex to compute DockQ against

    Returns:
        dict with keys ``"total_dockq"`` (parsed float score),
        ``"raw_output"`` (decoded DockQ stdout) and ``"aligned"`` (``self``).

    Raises:
        ValueError: If either complex cannot be converted.
        RuntimeError: If the ``DockQ`` executable is not found, or if running
            it or parsing its output fails. The underlying exception is
            attached as ``__cause__`` in every case.
    """
    # Convert both complexes to ProteinComplex format
    try:
        self_pc = self.to_protein_complex()
        native_pc = native.to_protein_complex()
    except ValueError as e:
        raise ValueError(
            f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
        ) from e

    # Normalize chain IDs for PDB compatibility
    self_pc = self_pc.normalize_chain_ids_for_pdb()
    native_pc = native_pc.normalize_chain_ids_for_pdb()

    # Write temporary PDB files and run DockQ
    with TemporaryDirectory() as tdir:
        dir_path = Path(tdir)
        self_pdb = dir_path / "self.pdb"
        native_pdb = dir_path / "native.pdb"

        self_pc.to_pdb(self_pdb)
        native_pc.to_pdb(native_pdb)

        try:
            output = check_output(["DockQ", str(self_pdb), str(native_pdb)])
            output_text = output.decode()
            lines = output_text.split("\n")

            # Preferred source: the "Total DockQ ...: <score>" summary line.
            dockq_score = None
            for line in lines:
                if "Total DockQ" in line:
                    match = re.search(r"Total DockQ.*: ([\d.]+)", line)
                    if match:
                        dockq_score = float(match.group(1))
                        break

            if dockq_score is None:
                # Otherwise take the first "DockQ...: <number>" line that parses.
                for line in lines:
                    if line.startswith("DockQ") and ":" in line:
                        try:
                            dockq_score = float(line.split(":")[1].strip())
                            break
                        except (ValueError, IndexError):
                            continue

            if dockq_score is None:
                # Caught by the generic handler below, so callers see a
                # RuntimeError with this message embedded.
                raise ValueError("Could not parse DockQ score from output")

            # Return a simple result structure
            return {
                "total_dockq": dockq_score,
                "raw_output": output_text,
                "aligned": self,  # Return self as aligned structure
            }

        except FileNotFoundError as e:
            raise RuntimeError(
                "DockQ is not installed. Please install DockQ to use this method."
            ) from e
        except Exception as e:
            raise RuntimeError(f"DockQ computation failed: {e}") from e
|
||||
|
||||
def rmsd(self, target: "MolecularComplex", **kwargs) -> float:
    """Compute RMSD against target structure using per-token centroids.

    Every token is represented by the mean position of its atoms. Tokens for
    which either complex has no atoms are left out of the comparison.

    Args:
        target: Target MolecularComplex to compute RMSD against; it must have
            the same number of tokens as this complex
        **kwargs: Additional arguments passed to compute_rmsd

    Returns:
        float: RMSD value between the two structures

    Raises:
        ValueError: If the token counts differ, or if no token has atoms in
            both complexes.
    """
    n_tokens = len(self)
    if n_tokens != len(target):
        raise ValueError(
            f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
        )

    # One (mobile centroid, target centroid) pair per token present in both.
    centroid_pairs = []
    for token_idx in range(n_tokens):
        own_lo, own_hi = self.token_to_atoms[token_idx]
        ref_lo, ref_hi = target.token_to_atoms[token_idx]

        own_xyz = self.atom_positions[own_lo:own_hi]
        ref_xyz = target.atom_positions[ref_lo:ref_hi]

        if len(own_xyz) != 0 and len(ref_xyz) != 0:
            centroid_pairs.append((own_xyz.mean(axis=0), ref_xyz.mean(axis=0)))

    if not centroid_pairs:
        raise ValueError("No valid atoms found for RMSD computation")

    own_centroids, ref_centroids = zip(*centroid_pairs)

    # Add a leading batch dimension of size 1: [1, N, 3] and [1, N]
    mobile_tensor = torch.from_numpy(np.stack(own_centroids, axis=0)).unsqueeze(0)
    target_tensor = torch.from_numpy(np.stack(ref_centroids, axis=0)).unsqueeze(0)
    # Only tokens with atoms were kept, so every mask entry is True.
    mask_tensor = torch.tensor(
        [True] * len(centroid_pairs), dtype=torch.bool
    ).unsqueeze(0)

    rmsd_value = compute_rmsd(
        mobile=mobile_tensor,
        target=target_tensor,
        atom_exists_mask=mask_tensor,
        reduction="batch",
        **kwargs,
    )
    return float(rmsd_value)
|
||||
|
||||
def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float:
    """Compute LDDT score against target structure using per-token centroids.

    Every token is represented by the mean position of its atoms (not by a
    single named atom). Tokens for which either complex has no atoms keep
    their slot, filled with NaN coordinates and masked out.

    Args:
        target: Target MolecularComplex to compute LDDT against; it must have
            the same number of tokens as this complex
        **kwargs: Additional arguments passed to compute_lddt

    Returns:
        float: LDDT value between the two structures

    Raises:
        ValueError: If the token counts differ, or if no token has atoms in
            both complexes.
    """
    n_tokens = len(self)
    if n_tokens != len(target):
        raise ValueError(
            f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
        )

    mobile_coords = []
    target_coords = []
    atom_mask = []

    for token_idx in range(n_tokens):
        own_lo, own_hi = self.token_to_atoms[token_idx]
        ref_lo, ref_hi = target.token_to_atoms[token_idx]

        own_xyz = self.atom_positions[own_lo:own_hi]
        ref_xyz = target.atom_positions[ref_lo:ref_hi]

        present = len(own_xyz) != 0 and len(ref_xyz) != 0
        if present:
            # Centroid of the token's atoms is its representative position
            mobile_coords.append(own_xyz.mean(axis=0))
            target_coords.append(ref_xyz.mean(axis=0))
        else:
            # Placeholder row keeps token indices aligned; the mask hides it
            mobile_coords.append(np.full(3, np.nan))
            target_coords.append(np.full(3, np.nan))
        atom_mask.append(present)

    if True not in atom_mask:
        raise ValueError("No valid atoms found for LDDT computation")

    # Add a leading batch dimension of size 1: [1, N, 3] and [1, N]
    mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(0)
    target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(0)
    mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0)

    lddt_value = compute_lddt(
        all_atom_pred_pos=mobile_tensor,
        all_atom_positions=target_tensor,
        all_atom_mask=mask_tensor,
        per_residue=False,  # Return overall LDDT score
        **kwargs,
    )
    return float(lddt_value)
|
||||
|
||||
def state_dict(self):
    """Return the instance attributes as a storage-friendly dict.

    NumPy arrays are down-cast where possible (int64 -> int32, float64 and
    float32 -> float16) and turned into plain lists so the result can be
    serialized; a MolecularComplexMetadata value is expanded with ``asdict``.
    All other attributes are passed through unchanged.
    """
    packed = {}
    for name, value in vars(self).items():
        if isinstance(value, np.ndarray):
            if value.dtype == np.int64:
                value = value.astype(np.int32)
            elif value.dtype == np.float64 or value.dtype == np.float32:
                value = value.astype(np.float16)
            packed[name] = value.tolist()
        elif isinstance(value, MolecularComplexMetadata):
            packed[name] = asdict(value)
        else:
            packed[name] = value
    return packed
|
||||
|
||||
def to_blob(self) -> bytes:
    """Serialize ``state_dict()`` with msgpack and compress it with brotli (quality 5)."""
    packed_state = msgpack.dumps(self.state_dict())
    return brotli.compress(packed_state, quality=5)
|
||||
|
||||
@classmethod
def from_state_dict(cls, dct):
    """Build an instance from a dict such as the one produced by ``state_dict()``.

    List values stored under the array keys are converted to NumPy arrays;
    ``atom_positions`` and ``plddt`` are cast to float32 and
    ``token_to_atoms`` to int32. The ``"metadata"`` entry is turned into a
    MolecularComplexMetadata. The input dict is not modified, so the same
    dict can be passed again.

    Args:
        dct: Mapping of constructor keyword arguments; must contain a
            ``"metadata"`` entry.

    Returns:
        A new instance created as ``cls(**state)``.
    """
    float_keys = ("atom_positions", "plddt")
    int_keys = ("token_to_atoms",)
    array_keys = (*float_keys, "atom_elements", *int_keys)

    # Collect converted values in a fresh dict instead of writing back into
    # the caller's dict.
    state = {}
    for key, value in dct.items():
        if key in array_keys and isinstance(value, list):
            value = np.array(value)
        if isinstance(value, np.ndarray):
            if key in float_keys:
                value = value.astype(np.float32)
            elif key in int_keys:
                value = value.astype(np.int32)
        state[key] = value

    metadata = state["metadata"]
    if not isinstance(metadata, MolecularComplexMetadata):
        # Serialized form: a plain mapping of the metadata fields.
        metadata = MolecularComplexMetadata(**metadata)
    state["metadata"] = metadata

    return cls(**state)
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
|
||||
match input:
|
||||
case Path() | str():
|
||||
bytes = Path(input).read_bytes()
|
||||
case io.BytesIO():
|
||||
bytes = input.getvalue()
|
||||
case _:
|
||||
bytes = input
|
||||
return cls.from_state_dict(
|
||||
msgpack.loads(brotli.decompress(bytes), strict_map_key=False)
|
||||
)
|
||||
@@ -25,13 +25,21 @@ from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_gdt_ts,
|
||||
compute_lddt_ca,
|
||||
)
|
||||
from esm.utils.structure.mmcif_parsing import (
|
||||
MmcifWrapper,
|
||||
Residue,
|
||||
)
|
||||
from esm.utils.structure.normalize_coordinates import (
|
||||
apply_frame_to_coords,
|
||||
get_protein_normalization_frame,
|
||||
)
|
||||
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||
from esm.utils.structure.protein_structure import (
|
||||
index_by_atom_name,
|
||||
)
|
||||
from esm.utils.types import PathOrBuffer
|
||||
|
||||
msgpack_numpy.patch()
|
||||
@@ -393,6 +401,7 @@ class ProteinChain:
|
||||
bytes = input
|
||||
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
|
||||
|
||||
|
||||
def sasa(self, by_residue: bool = True):
|
||||
arr = self.atom_array_no_insertions
|
||||
sasa_per_atom = bs.sasa(arr) # type: ignore
|
||||
@@ -698,6 +707,7 @@ class ProteinChain:
|
||||
)
|
||||
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
||||
|
||||
|
||||
@classmethod
|
||||
def chain_iterable_from_mmcif(
|
||||
cls,
|
||||
|
||||
@@ -32,8 +32,14 @@ from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_gdt_ts,
|
||||
compute_lddt_ca,
|
||||
)
|
||||
from esm.utils.structure.mmcif_parsing import (
|
||||
MmcifWrapper,
|
||||
NoProteinError,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import (
|
||||
ProteinChain,
|
||||
chain_to_ndarray,
|
||||
|
||||
45
esm/utils/system.py
Normal file
45
esm/utils/system.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import io
|
||||
import subprocess
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
PathLike = T.Union[str, Path]
|
||||
PathOrBuffer = T.Union[PathLike, io.StringIO]
|
||||
|
||||
|
||||
def run_subprocess_with_errorcheck(
|
||||
*popenargs,
|
||||
capture_output: bool = False,
|
||||
quiet: bool = False,
|
||||
env: dict[str, str] | None = None,
|
||||
shell: bool = False,
|
||||
executable: str | None = None,
|
||||
**kws,
|
||||
) -> subprocess.CompletedProcess:
|
||||
"""A command similar to subprocess.run, however the errormessage will
|
||||
contain the stderr when using this function. This makes it significantly
|
||||
easier to diagnose issues.
|
||||
"""
|
||||
try:
|
||||
if capture_output:
|
||||
stdout = subprocess.PIPE
|
||||
elif quiet:
|
||||
stdout = subprocess.DEVNULL
|
||||
else:
|
||||
stdout = None
|
||||
|
||||
p = subprocess.run(
|
||||
*popenargs,
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=stdout,
|
||||
check=True,
|
||||
env=env,
|
||||
shell=shell,
|
||||
executable=executable,
|
||||
**kws,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(
|
||||
f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}"
|
||||
)
|
||||
return p
|
||||
@@ -4,7 +4,9 @@ import pygtrie
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import FunctionAnnotation
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
|
||||
TRIE: pygtrie.CharTrie | None = None
|
||||
|
||||
|
||||
@@ -7,11 +7,15 @@ import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from esm.sdk.api import ESMProtein
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_function_annotations import (
|
||||
draw_function_annotations,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
||||
draw_protein_structure,
|
||||
)
|
||||
from esm.widgets.utils.serialization import (
|
||||
create_download_button_from_buffer,
|
||||
protein_to_pdb_buffer,
|
||||
|
||||
@@ -3,9 +3,16 @@ from typing import Any, Callable, Sequence
|
||||
import ipywidgets as widgets
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_hex,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,16 @@ import ipywidgets as widgets
|
||||
import pydssp
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_hex,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_rgba_html_string,
|
||||
)
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -10,8 +10,12 @@ from matplotlib.patches import Rectangle
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils import indexing
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
||||
draw_protein_structure,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
@@ -9,7 +9,10 @@ from matplotlib import colormaps
|
||||
from PIL import Image
|
||||
|
||||
from esm.sdk.api import FunctionAnnotation
|
||||
from esm.utils.function.interpro import InterPro, InterProEntryType
|
||||
from esm.utils.function.interpro import (
|
||||
InterPro,
|
||||
InterProEntryType,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -9,7 +9,9 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
|
||||
from esm.utils import encoding
|
||||
from esm.widgets.utils import indexing
|
||||
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,13 @@ from esm.sdk.api import (
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.utils.constants import models
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.serialization import create_download_results_button
|
||||
from esm.widgets.utils.serialization import (
|
||||
create_download_results_button,
|
||||
)
|
||||
|
||||
|
||||
def create_esm3_generation_launcher(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
|
||||
from esm.widgets.components.sasa_prompt_selector import (
|
||||
create_sasa_prompt_selector,
|
||||
)
|
||||
from esm.widgets.components.secondary_structure_prompt_selector import (
|
||||
create_secondary_structure_prompt_selector,
|
||||
)
|
||||
|
||||
@@ -4,12 +4,20 @@ from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import ESM3InferenceClient, ESMProtein
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.widgets.components.function_annotator import create_function_annotator
|
||||
from esm.widgets.components.function_annotator import (
|
||||
create_function_annotator,
|
||||
)
|
||||
from esm.widgets.utils.prompting import PromptManagerCollection
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
|
||||
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
|
||||
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
|
||||
from esm.widgets.views.esm3_generation_launcher import (
|
||||
create_esm3_generation_launcher,
|
||||
)
|
||||
from esm.widgets.views.esm3_prompt_preview import (
|
||||
create_esm3_prompt_preview,
|
||||
)
|
||||
from esm.widgets.views.esm3_prompt_selector import (
|
||||
create_esm3_prompt_selector,
|
||||
)
|
||||
|
||||
|
||||
def create_generation_ui(
|
||||
|
||||
@@ -6,7 +6,9 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
@@ -4,7 +4,10 @@ from textwrap import dedent
|
||||
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.widgets.utils.clients import get_forge_client, get_local_client
|
||||
from esm.widgets.utils.clients import (
|
||||
get_forge_client,
|
||||
get_local_client,
|
||||
)
|
||||
from esm.widgets.utils.types import ClientInitContainer
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.2.2"
|
||||
version = "3.2.2.post1"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
@@ -45,7 +45,6 @@ dependencies = [
|
||||
"pygtrie",
|
||||
"dna_features_viewer",
|
||||
]
|
||||
|
||||
# Pytest
|
||||
[tool.pytest.ini_options]
|
||||
addopts = """
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
esm
|
||||
esm >=3.2.1post1,<4.0.0
|
||||
pytest
|
||||
httpx # TODO(williamxi): Remove this after the esm repo is fixed
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from esm.sdk import client # pyright: ignore
|
||||
from esm.sdk.api import ( # pyright: ignore
|
||||
ESMProtein,
|
||||
@@ -37,6 +37,7 @@ def test_oss_esm3_client():
|
||||
logits_config = LogitsConfig(sequence=True, return_embeddings=True)
|
||||
result = esm3_client.logits(input=encoded_protein, config=logits_config)
|
||||
assert isinstance(result, LogitsOutput)
|
||||
assert isinstance(result.logits.sequence, torch.Tensor)
|
||||
|
||||
sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1))
|
||||
result = esm3_client.forward_and_sample(
|
||||
@@ -53,7 +54,7 @@ def test_oss_esm3_client():
|
||||
def test_oss_esmc_client():
|
||||
assert URL is not None
|
||||
|
||||
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
|
||||
sequence = "MALWMRLLPLLALLALAVPDPAAA"
|
||||
model = "esmc-300m-2024-12"
|
||||
esmc_client = client(model=model, url=URL, token=API_TOKEN)
|
||||
|
||||
@@ -69,13 +70,14 @@ def test_oss_esmc_client():
|
||||
)
|
||||
result = esmc_client.logits(input=encoded_protein, config=logits_config)
|
||||
assert isinstance(result, LogitsOutput)
|
||||
assert isinstance(result.logits.sequence, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.sdk
|
||||
def test_oss_sequence_structure_forge_inference_client():
|
||||
assert URL is not None
|
||||
|
||||
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
|
||||
sequence = "MALWMRLLPLLALLALAVPDPAAA"
|
||||
model = "esm3-small-2024-03"
|
||||
client = SequenceStructureForgeInferenceClient(
|
||||
model=model, url=URL, token=API_TOKEN
|
||||
|
||||
Reference in New Issue
Block a user