Sync over internal code to open source (#266)

Co-authored-by: Steve Chan <>
This commit is contained in:
Ishaan Mathur
2025-08-18 17:34:56 -04:00
committed by GitHub
parent 9abce48184
commit 95239e2d19
66 changed files with 8994 additions and 1037 deletions

View File

@@ -18,7 +18,7 @@ concurrency:
jobs:
test-precommit:
runs-on: ubuntu-22.04-16core
runs-on: ubuntu-24.04
steps:
- name: Checkout code
@@ -32,20 +32,11 @@ jobs:
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
- name: Check formatting and typing
run: |
set -e
# pyright seems to do something weird at initialization that causes it to error out
# We can ignore the first invocation here.
pyright esm/__init__.py || true
pre-commit install
env NODE_OPTIONS="--max-old-space-size=16384" pre-commit run --all-files --show-diff-on-failure
[ -z "$(git status --porcelain)" ] && true || (echo "❌❌❌ pre-commit hook failed! A few files changed ❌❌❌]" && git status --porcelain && false)
git reset --hard HEAD # test without the pre-commit changes
shell: pixi run bash -e {0}
run: pixi run lint-all
test-esm:
runs-on: ubuntu-22.04-16core
runs-on: ubuntu-24.04
steps:
- name: Checkout code
@@ -59,20 +50,18 @@ jobs:
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
- name: Run tests
run: |
set -o pipefail
pytest -v --junitxml=pytest.xml tests/ | tee pytest-coverage.txt
shell: pixi run bash -e {0}
run: pixi run cov-test
- name: Run Docker tests
env:
DOCKER_TAG: ${{ github.sha }}
FORGE_URL: https://forge.evolutionaryscale.ai/
ESM3_FORGE_TOKEN: ${{ secrets.ESM3_FORGE_TOKEN }}
run: |
set -e
cd tests
make build-oss-ci
make start-docker-oss URL=${{ env.FORGE_URL }} DOCKER_TAG=${{ env.DOCKER_TAG }} ESM3_FORGE_TOKEN=${{ secrets.ESM3_FORGE_TOKEN }}
make start-docker-oss URL=${{ env.FORGE_URL }} DOCKER_TAG=${{ env.DOCKER_TAG }} ESM3_FORGE_TOKEN=${{ env.ESM3_FORGE_TOKEN }}
shell: pixi run bash -e {0}
- name: cleanup docker containers if they're hanging

1
.gitignore vendored
View File

@@ -2,3 +2,4 @@ esm.egg-info
# pixi environments
.pixi
*.egg-info
*.pyc

33
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,33 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: (fasta|pdb|cif|mds|json)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
exclude: pixi.lock
- id: check-merge-conflict
- repo: https://github.com/seddonym/import-linter
rev: v1.12.1
hooks:
- id: import-linter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
hooks:
- id: ruff # linter
args: [ --fix ]
- id: ruff-format # formatter
types_or: [python, jupyter]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.399
hooks:
- id: pyright
name: pyright
entry: pyright
language: system
types: [python]
pass_filenames: true # For speed, we only check the files that are changed

View File

@@ -9,6 +9,7 @@
- [Installation ](#installation-)
- [Available Models](#available-models-)
- [ESM 3](#esm-3-)
- [Quickstart for ESM3 Open](#esm3-quickstart-)
- [ESM3 98B via Forge API](#esm3-forge)
@@ -33,6 +34,31 @@ To get started with ESM, install the python library using pip:
pip install esm
```
## Available Models <a name="available-models"></a>
### ESM 3 Family
| Model | Model Size | Release Date | Note |
|-------|------------|--------------|------|
| **Flagship Models** | | | Most users will be interested in using one of these models. |
| esm3-large-2024-03 | 98B | 2024-03 | |
| esm3-medium-2024-08 | 7B | 2024-08 | |
| esm3-small-2024-08 | 1.4B | 2024-08 | |
| **Published Models** | | | These models were used to generate all of the results in the ESM3 paper and are provided to facilitate reproducibility. |
| esm3-large-2024-03 | 98B | 2024-03 | |
| esm3-medium-2024-03 | 7B | 2024-03 | |
| esm3-small-2024-03 | 1.4B | 2024-03 | |
| **Experimental Models** | | | These models are provided for early use by researchers and are still under development. |
| esm3-medium-multimer-2024-09 | 7B | 2024-09 | |
### ESM C Models
| Model | Model Size | Number of Layers | Release Date |
|-------|------------|------------------|--------------|
| esmc-6b-2024-12 | 6B | 80 | 2024-12 |
| esmc-600m-2024-12 | 600M | 36 | 2024-12 |
| esmc-300m-2024-12 | 300M | 30 | 2024-12 |
## ESM 3 <a name="esm3"></a>
[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

View File

@@ -38,6 +38,7 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.models.esm3 import ESM3\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -13,9 +13,7 @@ from esm.tokenization import get_esm3_model_tokenizers
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

View File

@@ -72,6 +72,7 @@
"outputs": [],
"source": [
"from biotite.database import rcsb\n",
"\n",
"from esm.sdk.api import ESMProtein\n",
"from esm.utils.structure.protein_chain import ProteinChain\n",
"from esm.utils.types import FunctionAnnotation\n",
@@ -496,9 +497,10 @@
"# Functions for visualizing InterPro function annotations\n",
"\n",
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"from matplotlib import colormaps\n",
"\n",
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
"\n",
"\n",
"def visualize_function_annotations(\n",
" annotations: list[FunctionAnnotation],\n",

View File

@@ -64,6 +64,7 @@
"import matplotlib.pyplot as pl\n",
"import py3Dmol\n",
"import torch\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -36,6 +36,7 @@
"\n",
"!pip install py3Dmol\n",
"import py3Dmol\n",
"\n",
"from esm.sdk import client\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.utils.structure.protein_chain import ProteinChain"

View File

@@ -14,13 +14,13 @@
"3. Minimize a biophysical energy function\n",
"4. Use experimental screening data to guide designs with a regression model\n",
"\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252).\n",
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252) and constrained optimization using the Modified Differential Method of Multipliers from [Platt & Barr 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf)\n",
"\n",
"In this notebook we will walk through a few examples to illustrate how to use guided generation. \n",
"\n",
"1. Guide towards high pTM for improved generation quality\n",
"2. Generate a protein with no cysteine (C) residues\n",
"3. Maximize protein globularity by minimizing the radius of gyration\n",
"3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high\n",
"\n"
]
},
@@ -49,6 +49,7 @@
"source": [
"import biotite.structure as bs\n",
"import py3Dmol\n",
"\n",
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
]
@@ -269,6 +270,11 @@
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
@@ -302,7 +308,20 @@
"source": [
"## Maximize Globularity\n",
"\n",
"We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM"
"We use the radius of gyration as a proxy to maximize globularity, and we will also encourage generations to have high pTM by using constraints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from esm.sdk.experimental import (\n",
" ConstraintType,\n",
" ESM3GuidedDecodingWithConstraints,\n",
" GenerationConstraint,\n",
")"
]
},
{
@@ -313,12 +332,11 @@
"source": [
"class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
" def __call__(self, protein: ESMProtein) -> float:\n",
" # Use the negative radius of gyration as the score to maximize\n",
" score = -1 * self.radius_of_gyration(protein)\n",
"\n",
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
" if protein.ptm < 0.5:\n",
" # Penalize proteins with low pTM scores\n",
" score = score * 2\n",
" # Re-scale the score to be in a similar magnitude as pTM\n",
" score = score / 100.0\n",
"\n",
" return score\n",
"\n",
@@ -335,8 +353,19 @@
"metadata": {},
"outputs": [],
"source": [
"radius_guided_decoding = ESM3GuidedDecoding(\n",
" client=model, scoring_function=RadiousOfGyrationScoringFunction()\n",
"# Constrain generation to have pTM > 0.75\n",
"ptm_constraint = GenerationConstraint(\n",
" scoring_function=PTMScoringFunction(),\n",
" constraint_type=ConstraintType.GREATER_EQUAL,\n",
" value=0.75,\n",
")\n",
"\n",
"radius_guided_decoding = ESM3GuidedDecodingWithConstraints(\n",
" client=model,\n",
" scoring_function=RadiousOfGyrationScoringFunction(),\n",
" constraints=[ptm_constraint], # Add list of constraints\n",
" damping=1.0, # Damping factor for the MMDM algorithm\n",
" learning_rate=10.0, # Learning rate for the MMDM algorithm\n",
")"
]
},
@@ -346,6 +375,11 @@
"metadata": {},
"outputs": [],
"source": [
"# Start from a fully masked protein\n",
"PROTEIN_LENGTH = 256\n",
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
"\n",
"# Call guided_generate\n",
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
" protein=starting_protein,\n",
" num_decoding_steps=len(starting_protein) // 8,\n",
@@ -359,11 +393,32 @@
"metadata": {},
"outputs": [],
"source": [
"# Visualize the trajectory of the constrained generation\n",
"radius_guided_decoding.visualize_latest_trajectory()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the generated protein\n",
"view = py3Dmol.view(width=800, height=400)\n",
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
"view.zoomTo()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Check pTM\n",
"radius_guided_protein.ptm"
]
}
],
"metadata": {

View File

@@ -1,2 +1 @@
__version__ = "3.2.1"

View File

@@ -5,15 +5,12 @@ import torch
import torch.nn.functional as F
from torch import nn
from esm.layers.rotary import (
RotaryEmbedding,
TritonRotaryEmbedding,
)
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
try:
from flash_attn import flash_attn_varlen_qkvpacked_func # type:ignore
except ImportError:
flash_attn_varlen_func = None
from flash_attn import flash_attn_varlen_qkvpacked_func
except (ImportError, RuntimeError):
flash_attn_varlen_qkvpacked_func = None
class MultiHeadAttention(nn.Module):
@@ -117,7 +114,7 @@ class FlashMultiHeadAttention(MultiHeadAttention):
)
qkv_N3HD = self.rotary(qkv_N3HD, cu_seqlens, max_seqlen)
context_NHD = flash_attn_varlen_qkvpacked_func(
context_NHD = flash_attn_varlen_qkvpacked_func( # type: ignore
qkv_N3HD, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5
)
context_ND = einops.rearrange(context_NHD, "n h d -> n (h d)")

View File

@@ -2,13 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.layers.attention import (
FlashMultiHeadAttention,
MultiHeadAttention,
)
from esm.layers.geom_attention import (
GeometricReasoningOriginalImpl,
)
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
from esm.utils.structure.affine3d import Affine3D

View File

@@ -2,10 +2,7 @@ import torch
import torch.nn as nn
from esm.utils.constants.physics import BB_COORDINATES
from esm.utils.structure.affine3d import (
Affine3D,
RotationMatrix,
)
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
class Dim6RotStructureHead(nn.Module):

View File

@@ -13,10 +13,7 @@ from attr import dataclass
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
@@ -32,10 +29,7 @@ from esm.sdk.api import (
from esm.tokenization import TokenizerCollectionProtocol
from esm.utils import encoding
from esm.utils.constants import esm3 as C
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from esm.utils.decoding import decode_protein_tensor
from esm.utils.generation import (
_batch_forward,
@@ -50,9 +44,7 @@ from esm.utils.sampling import (
get_default_sampling_config,
validate_sampling_config,
)
from esm.utils.structure.affine3d import (
build_affine3d_from_coordinates,
)
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
@dataclass

View File

@@ -12,9 +12,7 @@ from cloudpathlib import AnyPath
from esm.layers.regression_head import RegressionHead
from esm.layers.transformer_stack import TransformerStack
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.utils.constants import esm3 as C
from esm.utils.misc import merge_annotations, merge_ranges
from esm.utils.types import FunctionAnnotation

View File

@@ -7,10 +7,7 @@ from esm.layers.structure_proj import Dim6RotStructureHead
from esm.layers.transformer_stack import TransformerStack
from esm.utils.constants import esm3 as C
from esm.utils.misc import knn_graph
from esm.utils.structure.affine3d import (
Affine3D,
build_affine3d_from_coordinates,
)
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
from esm.utils.structure.predicted_aligned_error import (
compute_predicted_aligned_error,
compute_tm,

View File

@@ -6,14 +6,8 @@ import torch.nn as nn
from esm.models.esm3 import ESM3
from esm.models.esmc import ESMC
from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import (
StructureTokenDecoder,
StructureTokenEncoder,
)
from esm.tokenization import (
get_esm3_model_tokenizers,
get_esmc_model_tokenizers,
)
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
from esm.utils.constants.esm3 import data_root
from esm.utils.constants.models import (
ESM3_FUNCTION_DECODER_V0,

View File

@@ -2,27 +2,19 @@ from __future__ import annotations
from abc import ABC
from copy import deepcopy
from typing import List, Sequence
from typing import Sequence
import attr
import torch
from attr import asdict, define
import esm.utils.constants.api as C
from esm.tokenization import (
TokenizerCollectionProtocol,
get_esm3_model_tokenizers,
)
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.misc import (
get_chainbreak_boundaries_from_sequence,
)
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import (
SINGLE_LETTER_CHAIN_IDS,
ProteinComplex,
)
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
from esm.utils.types import FunctionAnnotation, PathOrBuffer
@@ -43,7 +35,6 @@ class ESMProtein(ProteinType):
plddt: torch.Tensor | None = None
ptm: torch.Tensor | None = None
# When calling EvolutionaryScale API, use this flag to disclose any
# sequences that may potentially have concerns.
# Such sequences may not go through standard safety filter for approved users.
@@ -79,12 +70,9 @@ class ESMProtein(ProteinType):
def from_protein_chain(
cls, protein_chain: ProteinChain, with_annotations: bool = False
) -> ESMProtein:
# By default, we don't annotate with DSSP / SASA, which are expensive.
# If mkdssp is installed, we can annotate with a flag.
if with_annotations:
return ESMProtein(
sequence=protein_chain.sequence,
secondary_structure=protein_chain.dssp().tolist(),
sasa=protein_chain.sasa().tolist(),
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
@@ -123,7 +111,8 @@ class ESMProtein(ProteinType):
protein_complex.to_pdb(pdb_path)
def to_pdb_string(self) -> str:
protein_chain = self.to_protein_chain()
# Note: This was modified to match .to_pdb() behavior. We can revisit this at some point
protein_chain = self.to_protein_complex().infer_oxygen()
return protein_chain.to_pdb_string()
def to_protein_chain(self) -> ProteinChain:
@@ -172,6 +161,7 @@ class ESMProtein(ProteinType):
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
confidence=self.plddt[start:end] if self.plddt is not None else None,
)
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
@@ -321,8 +311,6 @@ class InverseFoldingConfig:
temperature: float = 1.0
## Low Level Endpoint Types
@define
class SamplingTrackConfig:
@@ -382,22 +370,23 @@ class LogitsConfig:
# Embeddings.
return_embeddings: bool = False
return_hidden_states: bool = False
return_mean_embedding: bool = False
return_mean_hidden_states: bool = False
ith_hidden_layer: int = -1
@define
class LogitsOutput:
logits: ForwardTrackData | None = None
embeddings: torch.Tensor | None = None
mean_embedding: torch.Tensor | None = None
# Residue annotations is multi-hot, so deserves special treatment
# It's not a categorical distribution, but instead a bernoulli, so
# softmax across the last dimension is _wrong_
residue_annotation_logits: torch.Tensor | None = None
hidden_states: torch.Tensor | None = None
mean_hidden_state: torch.Tensor | None = None
@define

View File

@@ -70,11 +70,13 @@ class _BaseForgeInferenceClient:
def prepare_request(
self,
request: dict[str, Any],
potential_sequence_of_concern: bool = False,
potential_sequence_of_concern: bool | None = None,
return_bytes: bool = False,
headers: dict[str, str] = {},
) -> tuple[dict[str, Any], dict[str, str]]:
request["potential_sequence_of_concern"] = potential_sequence_of_concern
if potential_sequence_of_concern is not None:
request["potential_sequence_of_concern"] = potential_sequence_of_concern
headers = {**self.headers, **headers}
if return_bytes:
headers["return-bytes"] = "true"
@@ -103,42 +105,58 @@ class _BaseForgeInferenceClient:
self,
endpoint,
request,
potential_sequence_of_concern: bool = False,
potential_sequence_of_concern: bool | None = None,
params: dict[str, Any] = {},
headers: dict[str, str] = {},
return_bytes: bool = False,
):
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
)
response = await self.async_client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
json=request,
params=params,
headers=headers,
timeout=self.request_timeout,
)
data = self.prepare_data(response, endpoint)
return data
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
)
response = await self.async_client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
json=request,
params=params,
headers=headers,
timeout=self.request_timeout,
)
data = self.prepare_data(response, endpoint)
return data
except ESMProteinError as e:
raise e
except Exception as e:
raise ESMProteinError(
error_code=500,
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
)
def _post(
self,
endpoint,
request,
potential_sequence_of_concern: bool = False,
potential_sequence_of_concern: bool | None = None,
params: dict[str, Any] = {},
headers: dict[str, str] = {},
return_bytes: bool = False,
):
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
)
response = self.client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
json=request,
params=params,
headers=headers,
timeout=self.request_timeout,
)
data = self.prepare_data(response, endpoint)
return data
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
)
response = self.client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
json=request,
params=params,
headers=headers,
timeout=self.request_timeout,
)
data = self.prepare_data(response, endpoint)
return data
except ESMProteinError as e:
raise e
except Exception as e:
raise ESMProteinError(
error_code=500,
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
)

View File

@@ -0,0 +1,14 @@
from .constrained_generation import (
ConstraintType,
ESM3GuidedDecodingWithConstraints,
GenerationConstraint,
)
from .guided_generation import ESM3GuidedDecoding, GuidedDecodingScoringFunction
__all__ = [
"ConstraintType",
"ESM3GuidedDecodingWithConstraints",
"GenerationConstraint",
"ESM3GuidedDecoding",
"GuidedDecodingScoringFunction",
]

View File

@@ -0,0 +1,324 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from tqdm import tqdm
from esm.sdk import batch_executor
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
ESMProteinError,
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
from esm.sdk.experimental.guided_generation import (
ESM3GuidedDecoding,
GuidedDecodingScoringFunction,
)
class ConstraintType(Enum):
GREATER_EQUAL = "greater_equal" # f(x) ≥ threshold
LESS_EQUAL = "less_equal" # f(x) ≤ threshold
EQUAL = "equal" # f(x) = threshold
@dataclass(slots=True)
class GenerationConstraint:
"""
A single inequality or equality constraint.
Parameters
----------
scoring_function
Maps a protein ➜ real value (e.g. pTM, length, …).
value
Target value for inequality or equality constraint.
constraint_type
Type of constraint to apply.
- GREATER_EQUAL: f(x) ≥ value
- LESS_EQUAL: f(x) ≤ value
- EQUAL: f(x) = value (equality)
"""
scoring_function: GuidedDecodingScoringFunction
value: float
constraint_type: ConstraintType = ConstraintType.GREATER_EQUAL
lambda_: float = field(default=0.0, init=False) # dual variable (MDMM)
def g(self, x: float) -> float:
"""
Canonical form:
• inequalities → g(x) ≤ 0
• equalities → h(x) (we still return a scalar, no ≤)
"""
if self.constraint_type is ConstraintType.GREATER_EQUAL:
return self.value - x
if self.constraint_type is ConstraintType.LESS_EQUAL:
return x - self.value
# equality: h(x) = x - value (we will *not* project λ)
return x - self.value
def update_lambda(self, g: float, eta: float, gamma: float) -> None:
"""
Update the dual variable λ according to the MDMM update rule.
"""
if self.constraint_type is ConstraintType.EQUAL:
self.lambda_ += eta * g # no projection for equality constraints
else:
self.lambda_ = max(0.0, self.lambda_ + eta * g)
def copy(self) -> GenerationConstraint:
"""
Create a copy of this constraint.
"""
c = GenerationConstraint(
scoring_function=self.scoring_function,
value=self.value,
constraint_type=self.constraint_type,
)
c.lambda_ = self.lambda_ # copy the dual variable
return c
class ESM3GuidedDecodingWithConstraints(ESM3GuidedDecoding):
"""
Derivative-free guided decoding with constraints.
Uses the Modified Differential Method of Multipliers (MDMM) to
guarantee convergence to the constrained optimum without
hand-tuning penalty weights.
References:
[1] Platt, John, and Alan Barr. "Constrained differential optimization." Neural Information Processing Systems. 1987.
[2] https://www.engraved.blog/how-we-can-make-machine-learning-algorithms-tunable/
"""
def __init__(
self,
client: ESM3InferenceClient,
scoring_function: GuidedDecodingScoringFunction,
constraints: GenerationConstraint | list[GenerationConstraint],
*,
damping: float = 10.0,
learning_rate: float = 1.0,
):
super().__init__(client, scoring_function)
if isinstance(constraints, GenerationConstraint):
constraints = [constraints]
self.constraints = [c.copy() for c in constraints]
self.gamma = float(damping)
self.eta = float(learning_rate)
self.recorder: TrajectoryRecorder | None = None
def guided_generate(
self,
protein: ESMProtein,
num_decoding_steps: int,
num_samples_per_step: int,
denoised_prediction_temperature: float = 0.0,
track: str = "sequence",
verbose: bool = True,
) -> ESMProtein:
# Reset the trajectory recorder
self.recorder = TrajectoryRecorder()
protein_tensor = self.client.encode(protein)
assert not isinstance(protein_tensor, ESMProteinError)
if track == "structure":
protein_tensor = self.maybe_add_default_structure_tokens(protein_tensor)
n_masked = self.get_number_of_masked_positions(protein_tensor, track=track)
n_unmask = n_masked // num_decoding_steps
best_reward = float("-inf")
if verbose:
pbar = tqdm(range(num_decoding_steps), desc="S=-inf λ=0.00")
else:
pbar = range(num_decoding_steps)
for step in pbar:
# Last iteration: unmask whatever is left
if step == num_decoding_steps - 1:
n_unmask = self.get_number_of_masked_positions(
protein_tensor, track=track
)
# ---------- propose & evaluate M samples (parallel-safe) ---- #
def _propose_and_eval(pt: ESMProteinTensor):
new_pt = self.randomly_unmask_positions(pt, n_unmask, track=track)
reward, g_val, raw_vals = self._score_and_constraints(
new_pt, denoised_prediction_temperature
)
return new_pt, reward, g_val, raw_vals
if self._use_batch_executor:
with batch_executor(show_progress=False) as ex:
results = ex.execute_batch(
user_func=_propose_and_eval,
pt=[protein_tensor] * num_samples_per_step,
)
if isinstance(results, Exception):
raise results
samples, rewards, gh_lists, val_lists = zip(*results) # type: ignore
else:
samples, rewards, gh_lists, val_lists = [], [], [], []
for _ in range(num_samples_per_step):
s, r, g, c = _propose_and_eval(protein_tensor)
samples.append(s)
rewards.append(r)
gh_lists.append(g)
val_lists.append(c)
# -------- compute MDMM lagrangian for each sample -----------
lags = [self._lagrangian(r, ghs) for r, ghs in zip(rewards, gh_lists)]
best_idx = int(torch.tensor(lags).argmin())
protein_tensor = samples[best_idx]
best_reward = rewards[best_idx]
best_g_vals = gh_lists[best_idx]
# -------- dual updates (MDMM) -----------------
for g, c in zip(best_g_vals, self.constraints):
c.update_lambda(g, self.eta, self.gamma)
self.recorder.log(
step=step,
reward=best_reward,
g_list=best_g_vals,
lambda_list=[c.lambda_ for c in self.constraints],
)
if verbose and isinstance(pbar, tqdm):
lam_str = ", ".join(
f"λ_{i}={c.lambda_:.2f}" for i, c in enumerate(self.constraints)
)
pbar.set_description(f"S={best_reward:+.3f} {lam_str}")
final = self.client.forward_and_sample(
protein_tensor,
sampling_configuration=SamplingConfig(
sequence=SamplingTrackConfig(temperature=0.0),
structure=SamplingTrackConfig(temperature=0.0),
),
)
assert not isinstance(final, ESMProteinError)
decoded = self.client.decode(final.protein_tensor)
assert not isinstance(decoded, ESMProteinError)
return decoded
def visualize_latest_trajectory(
self, constraint_idx: int = 0, cmap: str = "viridis"
) -> None:
"""
Visualise the trajectory of the latest optimisation run.
If you optimise multiple constraints, pick which one to plot via `constraint_idx`.
"""
if not self.recorder:
raise ValueError("No trajectory recorder available.")
steps, g_vals, rewards = self.recorder.as_arrays(constraint_idx)
self.recorder.plot_line(constraint_idx=constraint_idx, cmap=cmap)
def _score_and_constraints(
self, pt: ESMProteinTensor, temp: float
) -> tuple[float, list[float], list[float]]:
protein = self.predict_denoised(pt, temperature=temp)
reward = self.scoring_function(protein)
vals, ghs = [], []
for c in self.constraints:
val = c.scoring_function(protein)
vals.append(val)
ghs.append(c.g(val))
return reward, ghs, vals
def _lagrangian(self, reward: float, g_vals: list[float]) -> float:
"""
MDMM L(x, λ) = -reward + Σ_i (λ_i - γ g_i) * g_i
(reward is to be maximised ⇒ we minimise -reward)
"""
lag = -reward
for g, c in zip(g_vals, self.constraints):
lag += (c.lambda_ - self.gamma * g) * g
return lag
@dataclass
class TrajectoryRecorder:
steps: List[int] = field(default_factory=list)
rewards: List[float] = field(default_factory=list)
g_vals: List[List[float]] = field(
default_factory=list
) # each step → list of constraints
lambdas: List[List[float]] = field(default_factory=list) # each step → list of λ s
def log(
self, step: int, reward: float, g_list: list[float], lambda_list: list[float]
) -> None:
"""Append one optimisation step to the trajectory."""
self.steps.append(step)
self.rewards.append(reward)
self.g_vals.append(list(g_list))
self.lambdas.append(list(lambda_list))
def as_arrays(self, constraint_idx: int = 0):
"""
Return numpy arrays suitable for plotting.
If you optimise multiple constraints, pick which one to plot via `constraint_idx`.
"""
return (
np.asarray(self.steps),
np.asarray([g[constraint_idx] for g in self.g_vals]),
np.asarray(self.rewards),
)
def plot_line(self, constraint_idx: int = 0, cmap: str = "viridis"):
"""
Continuous line with markers and a colour-gradient that follows the optimisation step.
"""
steps, x_vals, y_vals = self.as_arrays(constraint_idx)
# build coloured line segments
points = np.column_stack([x_vals, y_vals])
segments = np.concatenate([points[:-1, None, :], points[1:, None, :]], axis=1)
norm = Normalize(vmin=steps.min(), vmax=steps.max())
lc = LineCollection(segments, cmap=cmap, norm=norm, linewidth=2) # type: ignore
lc.set_array(steps)
fig, ax = plt.subplots()
ax.add_collection(lc)
ax.scatter(
x_vals,
y_vals,
c=steps,
cmap=cmap,
norm=norm,
marker="o",
edgecolor="k",
zorder=3,
)
ax.axvline(0, linestyle="--", color="grey")
ax.set_xlabel("constraint value g(x) (≤ 0 is feasible)")
ax.set_ylabel("reward R(x)")
ax.set_title("Trajectory in constraintreward space")
plt.colorbar(lc, label="optimisation step")
plt.tight_layout()
plt.show()

View File

@@ -44,7 +44,7 @@ class ESM3GuidedDecoding:
self._use_batch_executor = True
else:
raise ValueError(
"client must be an instance of ESM3 or ESM3ForgeInferenceClient"
f"client must be an instance of ESM3 or ESM3ForgeInferenceClient. Got {type(client)}"
)
self.client = client

View File

@@ -1,15 +1,10 @@
import asyncio
import base64
import inspect
import pickle
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
from functools import wraps
from typing import Any, Literal, Sequence, cast
from typing import Any, Sequence
import torch
from attr import asdict
from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential
from esm.sdk.api import (
MSA,
@@ -29,17 +24,11 @@ from esm.sdk.api import (
SamplingTrackConfig,
)
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import (
deserialize_tensors,
maybe_list,
maybe_tensor,
)
from esm.utils.sampling import validate_sampling_config
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
from esm.utils.types import FunctionAnnotation
skip_retries_var = ContextVar("skip_retries", default=False)
def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
if l is None or len(l) <= 0:
@@ -47,21 +36,17 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
return [FunctionAnnotation(*t) for t in l]
def retry_if_specific_error(exception):
"""
We only retry on specific errors.
Currently we retry for 502 (bad gateway) and 429 (rate limit)
"""
return isinstance(exception, ESMProteinError) and exception.error_code in {
429,
502,
504,
}
def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
ret = data.get("logits", {}).get(track, None)
# TODO(s22chan): just return this when removing return_bytes
return ret if ret is None or not return_bytes else maybe_tensor(ret)
def log_retry_attempt(retry_state):
print(
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
def _maybe_b64_decode(obj, return_bytes: bool):
return (
deserialize_tensors(base64.b64decode(obj))
if return_bytes and isinstance(obj, str)
else obj
)
@@ -77,45 +62,6 @@ def _validate_protein_tensor_input(input):
)
def retry_decorator(func):
"""
A static method that returns a retry decorator. This decorator uses the
instance's retry settings.
"""
@wraps(func)
async def async_wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return await func(instance, *args, **kwargs)
retry_decorator = retry(
retry=retry_if_result(retry_if_specific_error),
wait=wait_exponential(
multiplier=1, min=instance.min_retry_wait, max=instance.max_retry_wait
),
stop=stop_after_attempt(instance.max_retry_attempts),
before_sleep=log_retry_attempt,
)
# Apply the retry decorator to the function
return await retry_decorator(func)(instance, *args, **kwargs)
@wraps(func)
def wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return func(instance, *args, **kwargs)
retry_decorator = retry(
retry=retry_if_result(retry_if_specific_error),
wait=wait_exponential(
multiplier=1, min=instance.min_retry_wait, max=instance.max_retry_wait
),
stop=stop_after_attempt(instance.max_retry_attempts),
before_sleep=log_retry_attempt,
)
# Apply the retry decorator to the function
return retry_decorator(func)(instance, *args, **kwargs)
return async_wrapper if inspect.iscoroutinefunction(func) else wrapper
class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
def __init__(
self,
@@ -147,13 +93,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
)
@staticmethod
def _process_fold_request(
sequence: str,
model_name: str | None,
):
def _process_fold_request(sequence: str, model_name: str | None):
request: dict[str, Any] = {"sequence": sequence}
request["model"] = model_name
return request
@@ -164,12 +106,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
sequence=sequence,
coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True),
ptm=maybe_tensor(data.get("ptm", None)),
plddt=maybe_tensor(data.get("plddt", None)),
plddt=maybe_tensor(data.get("plddt", None), convert_none_to_nan=True),
)
@staticmethod
def process_inverse_fold_request(
coordinates: torch.Tensor, config: InverseFoldingConfig, model_name: str | None
coordinates: torch.Tensor,
sequence: str | None,
config: InverseFoldingConfig,
model_name: str | None,
):
inverse_folding_config = {
"invalid_ids": config.invalid_ids,
@@ -178,6 +123,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
request = {
"coordinates": maybe_list(coordinates, convert_nan_to_none=True),
"inverse_folding_config": inverse_folding_config,
"sequence": sequence,
}
if model_name is not None:
request["model"] = model_name
@@ -222,13 +168,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
del potential_sequence_of_concern
request = self._process_fold_request(
sequence,
model_name if model_name is not None else self.model,
sequence, model_name if model_name is not None else self.model
)
# Intentionally not catching errors, so our higher level logic such as automatic
# batch runner gets a chance to handle different errors properly.
data = await self._async_post("fold", request)
try:
data = await self._async_post("fold", request)
except ESMProteinError as e:
return e
return self._process_fold_response(data, sequence)
@@ -249,16 +195,17 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
potential_sequence_of_concern: this parameter is largely deprecated
and ignored by the folding endpoint.
"""
del potential_sequence_of_concern
request = self._process_fold_request(
sequence,
model_name if model_name is not None else self.model,
sequence, model_name if model_name is not None else self.model
)
# Intentionally not catching errors, so our higher level logic such as automatic
# batch runner gets a chance to handle different errors properly.
data = self._post("fold", request)
try:
data = self._post("fold", request)
except ESMProteinError as e:
return e
return self._process_fold_response(data, sequence)
@@ -268,6 +215,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
coordinates: torch.Tensor,
config: InverseFoldingConfig,
potential_sequence_of_concern: bool,
sequence: str | None = None,
model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
"""Generate protein sequence from its structure.
@@ -282,14 +230,18 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
model_name: Override the client level model name if needed.
"""
request = self.process_inverse_fold_request(
coordinates, config, model_name if model_name is not None else self.model
coordinates,
sequence,
config,
model_name if model_name is not None else self.model,
)
# Intentionally not catching errors, so our higher level logic such as automatic
# batch runner gets a chance to handle different errors properly.
data = await self._async_post(
"inverse_fold", request, potential_sequence_of_concern
)
try:
data = await self._async_post(
"inverse_fold", request, potential_sequence_of_concern
)
except ESMProteinError as e:
return e
return ESMProtein(sequence=data["sequence"])
@@ -299,6 +251,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
coordinates: torch.Tensor,
config: InverseFoldingConfig,
potential_sequence_of_concern: bool,
sequence: str | None = None,
model_name: str | None = None,
) -> ESMProtein | ESMProteinError:
"""Generate protein sequence from its structure.
@@ -313,12 +266,16 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
model_name: Override the client level model name if needed.
"""
request = self.process_inverse_fold_request(
coordinates, config, model_name if model_name is not None else self.model
coordinates,
sequence,
config,
model_name if model_name is not None else self.model,
)
# Intentionally not catching errors, so our higher level logic such as automatic
# batch runner gets a chance to handle different errors properly.
data = self._post("inverse_fold", request, potential_sequence_of_concern)
try:
data = self._post("inverse_fold", request, potential_sequence_of_concern)
except ESMProteinError as e:
return e
return ESMProtein(sequence=data["sequence"])
@@ -442,6 +399,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
def _process_forward_and_sample_request(
input: ESMProteinTensor, sampling_configuration: SamplingConfig, model_name: str
) -> dict[str, Any]:
from esm.utils.sampling import validate_sampling_config
_validate_protein_tensor_input(input)
validate_sampling_config(sampling_configuration, on_invalid="raise")
@@ -623,7 +582,9 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
"function": config.function,
"residue_annotations": config.residue_annotations,
"return_embeddings": config.return_embeddings,
"return_mean_embedding": config.return_mean_embedding,
"return_hidden_states": config.return_hidden_states,
"return_mean_hidden_states": config.return_mean_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
}
@@ -634,34 +595,28 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
def _process_logits_response(
data: dict[str, Any], return_bytes: bool
) -> LogitsOutput:
def _maybe_logits(track: str):
ret = data.get("logits", {}).get(track, None)
# TODO(s22chan): just return this when removing return_bytes
return ret if ret is None or not return_bytes else maybe_tensor(ret)
def _maybe_b64_decode(obj):
return (
deserialize_tensors(base64.b64decode(obj))
if return_bytes and isinstance(obj, str)
else obj
)
logits = _maybe_b64_decode(data["logits"])
logits = _maybe_b64_decode(data["logits"], return_bytes)
data["logits"] = dict(logits) if logits is not None else logits
data["embeddings"] = _maybe_b64_decode(data["embeddings"])
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"])
data["embeddings"] = _maybe_b64_decode(data["embeddings"], return_bytes)
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
return LogitsOutput(
logits=ForwardTrackData(
sequence=_maybe_logits("sequence"),
structure=_maybe_logits("structure"),
secondary_structure=_maybe_logits("secondary_structure"),
sasa=_maybe_logits("sasa"),
function=_maybe_logits("function"),
sequence=_maybe_logits(data, "sequence", return_bytes),
structure=_maybe_logits(data, "structure", return_bytes),
secondary_structure=_maybe_logits(
data, "secondary_structure", return_bytes
),
sasa=_maybe_logits(data, "sasa", return_bytes),
function=_maybe_logits(data, "function", return_bytes),
),
embeddings=maybe_tensor(data["embeddings"]),
residue_annotation_logits=_maybe_logits("residue_annotation"),
mean_embedding=data["mean_embedding"],
residue_annotation_logits=_maybe_logits(
data, "residue_annotation", return_bytes
),
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
@retry_decorator
@@ -741,8 +696,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
try:
res = self.async_generate(protein, config)
return await res
except Exception as e:
return ESMProteinError(500, str(e))
except ESMProteinError as e:
return e
tasks = [
safe_generate(protein, config) for protein, config in zip(inputs, configs)
@@ -1009,6 +964,7 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
logits_config = {
"sequence": config.sequence,
"return_embeddings": config.return_embeddings,
"return_mean_embedding": config.return_mean_embedding,
"return_hidden_states": config.return_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
}
@@ -1019,26 +975,17 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
def _process_logits_response(
data: dict[str, Any], return_bytes: bool
) -> LogitsOutput:
def _maybe_logits(track: str):
ret = data.get("logits", {}).get(track, None)
# TODO(s22chan): just return this when removing return_bytes
return ret if ret is None or not return_bytes else maybe_tensor(ret)
def _maybe_b64_decode(obj):
return (
deserialize_tensors(base64.b64decode(obj))
if return_bytes and isinstance(obj, str)
else obj
)
logits = _maybe_b64_decode(data["logits"])
logits = _maybe_b64_decode(data["logits"], return_bytes)
data["logits"] = dict(logits) if logits is not None else logits
data["embeddings"] = _maybe_b64_decode(data["embeddings"])
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"])
data["embeddings"] = _maybe_b64_decode(data["embeddings"], return_bytes)
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
output = LogitsOutput(
logits=ForwardTrackData(sequence=_maybe_logits("sequence")),
logits=ForwardTrackData(
sequence=_maybe_logits(data, "sequence", return_bytes)
),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
hidden_states=maybe_tensor(data["hidden_states"]),
)
return output

76
esm/sdk/retry.py Normal file
View File

@@ -0,0 +1,76 @@
import inspect
from contextvars import ContextVar
from functools import wraps
import httpx
from tenacity import (
retry,
retry_if_exception_type,
retry_if_result,
stop_after_attempt,
wait_incrementing,
)
from esm.sdk.api import ESMProteinError
skip_retries_var = ContextVar("skip_retries", default=False)
def retry_if_specific_error(exception):
"""
We only retry on specific errors.
"""
return isinstance(exception, ESMProteinError) and exception.error_code in {
429,
500,
502,
504,
500,
}
def log_retry_attempt(retry_state):
print(
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
)
def retry_decorator(func):
"""
A static method that returns a retry decorator. This decorator uses the
instance's retry settings.
"""
@wraps(func)
async def async_wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return await func(instance, *args, **kwargs)
retry_decorator = retry(
retry=retry_if_result(retry_if_specific_error)
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),
stop=stop_after_attempt(instance.max_retry_attempts),
before_sleep=log_retry_attempt,
)
# Apply the retry decorator to the function
return await retry_decorator(func)(instance, *args, **kwargs)
@wraps(func)
def wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return func(instance, *args, **kwargs)
retry_decorator = retry(
retry=retry_if_result(retry_if_specific_error)
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),
stop=stop_after_attempt(instance.max_retry_attempts),
before_sleep=log_retry_attempt,
)
# Apply the retry decorator to the function
return retry_decorator(func)(instance, *args, **kwargs)
return async_wrapper if inspect.iscoroutinefunction(func) else wrapper

View File

@@ -22,7 +22,7 @@ class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
self._boto3_client = boto3.client(service_name="sagemaker-runtime")
def _post(self, endpoint, request, potential_sequence_of_concern):
def _post(self, endpoint, request, potential_sequence_of_concern: bool = False):
request["potential_sequence_of_concern"] = potential_sequence_of_concern
request["model"] = request.get("model", None)
invocations_request = {

View File

@@ -1,10 +1,7 @@
from dataclasses import dataclass
from typing import Protocol
from esm.utils.constants.models import (
ESM3_OPEN_SMALL,
normalize_model_name,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
from .function_tokenizer import InterProQuantizedTokenizer
from .residue_tokenizer import ResidueAnnotationsTokenizer

View File

@@ -10,24 +10,12 @@ from esm.models.function_decoder import FunctionTokenDecoder
from esm.models.vqvae import StructureTokenDecoder
from esm.sdk.api import ESMProtein, ESMProteinTensor
from esm.tokenization import TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.tokenization.tokenizer_base import EsmTokenizerBase
from esm.utils.constants import api as api_constants
from esm.utils.constants import esm3 as C
@@ -251,6 +239,7 @@ def assemble_message(headers: Mapping[str, str], response: Response) -> dict[str
content_type = headers.get("Content-Type", "application/json")
if content_type == api_constants.MIMETYPE_ES_PICKLE:
return pickle.loads(response.content)
elif content_type == "application/json":
elif "application/json" in content_type:
# Can handle something like "application/json; charset=utf-8"
return response.json()
raise ValueError(f"Unknown Content-Type: {content_type}")

View File

@@ -7,26 +7,13 @@ from esm.models.vqvae import StructureTokenEncoder
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer as EsmFunctionTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.tokenization.sasa_tokenizer import (
SASADiscretizingTokenizer,
)
from esm.tokenization.sequence_tokenizer import (
EsmSequenceTokenizer,
)
from esm.tokenization.ss_tokenizer import (
SecondaryStructureTokenizer,
)
from esm.tokenization.structure_tokenizer import (
StructureTokenizer,
)
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
from esm.tokenization.structure_tokenizer import StructureTokenizer
from esm.utils.constants import esm3 as C
from esm.utils.function.encode_decode import (
encode_function_annotations,
)
from esm.utils.function.encode_decode import encode_function_annotations
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation
@@ -165,8 +152,6 @@ def tokenize_function_annotations(
return function_tokens, residue_annotation_tokens
# Tokenized Defaults
def get_default_sequence_tokens(
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
@@ -242,5 +227,3 @@ def get_default_residue_annotation_tokens(
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
return residue_annotation_tokens

View File

@@ -7,10 +7,7 @@ from typing import Any, Callable, List
from tqdm import tqdm
from esm.sdk.api import ESMProteinError
from esm.sdk.forge import (
retry_if_specific_error,
skip_retries_var,
)
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
TQDM_BAR_FORMAT = (
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
@@ -25,7 +22,7 @@ class AIMDRateLimiter:
self,
initial_concurrency: int = 32,
min_concurrency: int = 1,
max_concurrency: int = 512,
max_concurrency: int = 64,
step_up: int = 1,
):
self.concurrency = initial_concurrency
@@ -56,8 +53,10 @@ class ForgeBatchExecutor:
"""
def __init__(
self, max_attempts: int = 10, max_workers: int = 512, show_progress: bool = True
self, max_attempts: int = 10, max_workers: int = 64, show_progress: bool = True
):
if max_workers > 64:
raise ValueError("max_workers must be less than 64")
self.rate_limiter = AIMDRateLimiter(max_concurrency=max_workers)
self.max_attempts = max_attempts
self.show_progress = show_progress

View File

@@ -3,16 +3,9 @@ from typing import Sequence
import torch
from esm.models.function_decoder import (
FunctionTokenDecoder,
merge_annotations,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.residue_tokenizer import (
ResidueAnnotationsTokenizer,
)
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
from esm.utils.constants import esm3 as C
from esm.utils.types import FunctionAnnotation

View File

@@ -65,7 +65,7 @@ class LSHTokenized:
hyperplanes: dict[str, np.ndarray] = { # type: ignore
str(i): table.hyperplanes for i, table in enumerate(self.tables)
}
np.savez(filepath, **hyperplanes)
np.savez(filepath, **hyperplanes) # type: ignore
def __call__(self, array):
tokens = np.stack([table(array) for table in self.tables], 1)

View File

@@ -19,13 +19,8 @@ from esm.sdk.api import (
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import (
EsmTokenizerBase,
TokenizerCollectionProtocol,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.utils.constants import esm3 as C
from esm.utils.misc import stack_variable_length_tensors
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY

View File

@@ -1,143 +0,0 @@
import pytest
import torch
from evolutionaryscale.models.esm3v2 import Esm3v2
from esm.sdk.api import (
ESMProtein,
ESMProteinTensor,
GenerationConfig,
)
from evolutionaryscale.utils.env import ModelName
from evolutionaryscale.utils.remote_inference.api_v1 import (
ESM3RemoteModelInferenceClient,
)
from projects.forge.inference.utils.model import _load_esm_model
@pytest.fixture()
def esm3_remote_inference_client():
model = _load_esm_model(
ModelName.ESM3_TINY_DEV, distributed_model=False, load_function_decoder=False
)
assert isinstance(model, Esm3v2)
client = ESM3RemoteModelInferenceClient(
model,
tokenizers=model.tokenizers,
device=torch.device("cuda"),
enable_batched_runner=False,
)
return client
@pytest.mark.gpu
def test_chain_break_tokens(esm3_remote_inference_client):
tokenizer = esm3_remote_inference_client.tokenizers.sequence
# 3 separate chains with 2 chainbreak tokens.
sequence_with_chain_breaks = torch.tensor(
[
tokenizer.bos_token_id,
20,
20,
20,
20,
tokenizer.chain_break_token_id,
21,
21,
21,
tokenizer.chain_break_token_id,
22,
22,
22,
tokenizer.eos_token_id,
]
)
protein = esm3_remote_inference_client.generate(
ESMProteinTensor(sequence=sequence_with_chain_breaks),
# There are 10 tokens that actually need to be sampled.
GenerationConfig(track="structure", num_steps=10),
)
assert isinstance(protein, ESMProteinTensor)
assert protein.structure is not None
@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens(esm3_remote_inference_client):
protein = esm3_remote_inference_client.generate(
esm3_remote_inference_client.encode(
ESMProtein(sequence="CDEFG")
), # sequence of 5.
GenerationConfig(track="structure", num_steps=10), # use 10 decoding steps.
)
# Client should handle over-specification of decoding steps.
# TODO: This should be a warning.
assert isinstance(protein, ESMProteinTensor)
assert protein.structure is not None
@pytest.mark.gpu
def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_client):
protein_list = esm3_remote_inference_client.batch_generate(
inputs=[
esm3_remote_inference_client.encode(ESMProtein(sequence="CDEFG")),
esm3_remote_inference_client.encode(ESMProtein(sequence="ABCDEFG")),
esm3_remote_inference_client.encode(ESMProtein(sequence="AB__EFG")),
],
configs=[
GenerationConfig(track="structure", num_steps=10),
GenerationConfig(track="structure", num_steps=3),
GenerationConfig(track="sequence", num_steps=20),
],
)
# Client should handle over-specification of decoding steps.
# TODO: This should be a warning.
assert isinstance(protein_list[0], ESMProteinTensor)
assert protein_list[0].structure is not None
assert isinstance(protein_list[1], ESMProteinTensor)
assert protein_list[1].structure is not None
assert isinstance(protein_list[2], ESMProteinTensor)
assert protein_list[2].sequence is not None
@pytest.mark.gpu
def test_encode_chainbreak_token(esm3_remote_inference_client):
protein = esm3_remote_inference_client.encode(ESMProtein(sequence="MSTNP|KPQKK"))
assert isinstance(protein, ESMProteinTensor)
assert protein.sequence is not None
assert (
protein.sequence[6]
== esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id
)
@pytest.mark.gpu
def test_generation_with_chainbreak_token(esm3_remote_inference_client):
chainbreak_sequence = torch.tensor(
[
esm3_remote_inference_client.tokenizers.sequence.bos_token_id,
20,
8,
11,
17,
14,
esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id,
15,
14,
16,
15,
15,
esm3_remote_inference_client.tokenizers.sequence.eos_token_id,
]
)
protein = esm3_remote_inference_client.generate(
ESMProteinTensor(sequence=chainbreak_sequence),
GenerationConfig(track="structure", num_steps=1),
)
# Can't specify more decoding steps than masks available.
assert isinstance(protein, ESMProteinTensor)
assert protein.structure is not None
assert (
protein.structure[6]
== esm3_remote_inference_client.tokenizers.structure.chain_break_token_id
)

View File

@@ -33,15 +33,15 @@ def slice_python_object_as_numpy(
>>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
[1, 2, 3]
"""
if isinstance(idx, int):
idx = [idx]
if np.isscalar(idx):
idx = [int(idx)] # type: ignore
if isinstance(idx, np.ndarray) and idx.dtype == bool:
sliced_obj = [obj[i] for i in np.where(idx)[0]]
elif isinstance(idx, slice):
sliced_obj = obj[idx]
else:
sliced_obj = [obj[i] for i in idx]
sliced_obj = [obj[i] for i in idx] # type: ignore
match obj, sliced_obj:
case str(), list():
@@ -156,6 +156,37 @@ def stack_variable_length_tensors(
return array
def binpack(
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
):
"""
Args:
tensor (Tensor): [B, L, ...]
Returns:
Tensor: [B_binpacked, L_binpacked, ...]
"""
if sequence_id is None:
return tensor
num_sequences = sequence_id.max(dim=-1).values + 1
dims = sequence_id.shape + tensor.shape[2:]
output_tensor = torch.full(
dims, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device
)
idx = 0
for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
zip(sequence_id, num_sequences)
):
for seqid in range(batch_num_sequences):
mask = batch_seqid == seqid
output_tensor[batch_idx, mask] = tensor[idx, : mask.sum()]
idx += 1
return output_tensor
def unbinpack(
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
):
@@ -280,9 +311,18 @@ def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
return None
if not convert_nan_to_none:
return x.tolist()
nan_mask = torch.isnan(x)
np_arr = x.cpu().numpy().astype(object)
np_arr[nan_mask.cpu().numpy()] = None
# Handle both torch.tensor and np.ndarray input.
if isinstance(x, torch.Tensor):
nan_mask = torch.isnan(x).cpu().numpy()
np_arr = x.cpu().numpy().astype(object)
elif isinstance(x, np.ndarray):
nan_mask = np.isnan(x)
np_arr = x.astype(object)
else:
raise TypeError("maybe_list can only work with torch.tensor or np.ndarray.")
np_arr[nan_mask] = None
return np_arr.tolist()
@@ -313,7 +353,6 @@ def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarr
return chain_boundaries
# TODO(return_bytes): remove when retiring return_bytes on SageMaker
def deserialize_tensors(b: bytes) -> Any:
buf = BytesIO(zstd.ZSTD_uncompress(b))
d = torch.load(buf, map_location="cpu", weights_only=False)

File diff suppressed because it is too large Load Diff

View File

@@ -5,18 +5,9 @@ import attr
import torch
import torch.nn.functional as F
from esm.sdk.api import (
ESMProteinTensor,
SamplingConfig,
SamplingTrackConfig,
)
from esm.tokenization import (
TokenizerCollectionProtocol,
get_invalid_tokenizer_ids,
)
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
from esm.utils.constants.esm3 import (
MAX_RESIDUE_ANNOTATIONS,
SASA_DISCRETIZATION_BOUNDARIES,

View File

@@ -1,16 +1,17 @@
from __future__ import annotations
import typing as T
from abc import ABC
from dataclasses import dataclass
import torch
from torch.nn import functional as F
from typing_extensions import Self
from esm.utils.misc import fp32_autocast_context
@T.runtime_checkable
class Rotation(T.Protocol):
class Rotation(ABC):
@classmethod
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
@@ -34,6 +35,8 @@ class Rotation(T.Protocol):
def as_matrix(self) -> RotationMatrix: ...
def as_quat(self, normalize: bool = False) -> RotationQuat: ...
def compose(self, other: Self) -> Self:
# To be safe, we force users to explicitly convert between rotation types.
...
@@ -87,7 +90,8 @@ class RotationMatrix(Rotation):
assert rots.shape[-1] == 3
assert rots.shape[-2] == 3
# Force full precision
self._rots = rots.to(torch.float32)
rots = rots.to(torch.float32)
self._rots = rots
@classmethod
def identity(cls, shape, **tensor_kwargs):
@@ -98,9 +102,7 @@ class RotationMatrix(Rotation):
@classmethod
def random(cls, shape, **tensor_kwargs):
v1 = torch.randn((*shape, 3), **tensor_kwargs)
v2 = torch.randn((*shape, 3), **tensor_kwargs)
return cls(_graham_schmidt(v1, v2))
return RotationQuat.random(shape, **tensor_kwargs).as_matrix()
def __getitem__(self, idx: T.Any) -> RotationMatrix:
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
@@ -113,6 +115,49 @@ class RotationMatrix(Rotation):
def as_matrix(self) -> RotationMatrix:
return self
def as_quat(self, normalize: bool = False) -> RotationQuat:
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
self._rots.flatten(-2), dim=-1
)
q_abs = _sqrt_subgradient(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
x
for lst in [
[q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
[m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
[m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
[m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
]
for x in lst
],
dim=-1,
).unflatten(-1, (4, 4))
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
# We manually implement one_hot so torch.compile works
one_hot = torch.zeros_like(q_abs, dtype=torch.bool)
one_hot.scatter_(-1, q_abs.argmax(dim=-1, keepdim=True), True)
quat = quat_candidates[one_hot, :].reshape(q_abs.shape)
return RotationQuat(quat)
def compose(self, other: RotationMatrix) -> RotationMatrix:
with fp32_autocast_context(self._rots.device.type):
return RotationMatrix(self._rots @ other._rots)
@@ -147,6 +192,81 @@ class RotationMatrix(Rotation):
return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
class RotationQuat(Rotation):
def __init__(self, quats: torch.Tensor, normalized=False):
assert quats.shape[-1] == 4
self._normalized = normalized
# Force float32 as well
if normalized:
self._quats = F.normalize(quats.to(torch.float32), dim=-1)
self._quats = self._quats.where(self._quats[..., :1] >= 0, -self._quats)
else:
self._quats = quats.to(torch.float32)
@classmethod
def identity(cls, shape, **tensor_kwargs):
q = torch.ones((*shape, 4), **tensor_kwargs)
mult = torch.tensor([1, 0, 0, 0], device=q.device)
return RotationQuat(q * mult)
@classmethod
def random(cls, shape, **tensor_kwargs):
quat = torch.randn((*shape, 4), **tensor_kwargs)
return RotationQuat(quat, normalized=True)
def __getitem__(self, idx: T.Any) -> RotationQuat:
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
return RotationQuat(self._quats[indices + (slice(None),)])
@property
def shape(self) -> torch.Size:
return self._quats.shape[:-1]
def compose(self, other: RotationQuat) -> RotationQuat:
with fp32_autocast_context(self._quats.device.type):
return RotationQuat(_quat_mult(self._quats, other._quats))
def convert_compose(self, other: Rotation):
return self.compose(other.as_quat())
def as_matrix(self) -> RotationMatrix:
q = self.normalized().tensor
r, i, j, k = torch.unbind(q, -1)
two_s = 2.0 / torch.linalg.norm(q, dim=-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return RotationMatrix(o.reshape(q.shape[:-1] + (3, 3)))
def as_quat(self, normalize: bool = False) -> RotationQuat:
return self
def apply(self, p: torch.Tensor) -> torch.Tensor:
return _quat_rotation(self.normalized()._quats, p)
def invert(self) -> RotationQuat:
return RotationQuat(_quat_invert(self._quats))
@property
def tensor(self) -> torch.Tensor:
return self._quats
def normalized(self) -> RotationQuat:
return self if self._normalized else RotationQuat(self._quats, normalized=True)
@dataclass(frozen=True)
class Affine3D:
trans: torch.Tensor
@@ -222,6 +342,9 @@ class Affine3D:
def as_matrix(self):
return Affine3D(trans=self.trans, rot=self.rot.as_matrix())
def as_quat(self, normalize: bool = False):
return Affine3D(trans=self.trans, rot=self.rot.as_quat(normalize))
def compose(self, other: "Affine3D", autoconvert: bool = False):
rot = self.rot
new_rot = (rot.convert_compose if autoconvert else rot.compose)(other.rot)
@@ -271,6 +394,13 @@ class Affine3D:
# Assume tensor 4x4 for backward compat with alphafold
trans = t[..., :3, 3]
rot = RotationMatrix(t[..., :3, :3])
case 6:
# Assume quaternion representation with real part = 1
trans = t[..., -3:]
rot = RotationQuat(F.pad(t[..., :3], (1, 0), value=1))
case 7:
trans = t[..., -3:]
rot = RotationQuat(t[..., :4])
case 12:
trans = t[..., -3:]
rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3)))
@@ -305,6 +435,62 @@ class Affine3D:
return Affine3D.from_tensor(torch.cat([x.tensor for x in affines], dim=dim))
def _quat_mult(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Multiply two quaternions.
Usual torch rules for broadcasting apply.
Args:
a: Quaternions as tensor of shape (..., 4), real part first.
b: Quaternions as tensor of shape (..., 4), real part first.
Returns:
The product of a and b, a tensor of quaternions shape (..., 4).
"""
aw, ax, ay, az = torch.unbind(a, -1)
bw, bx, by, bz = torch.unbind(b, -1)
ow = aw * bw - ax * bx - ay * by - az * bz
ox = aw * bx + ax * bw + ay * bz - az * by
oy = aw * by - ax * bz + ay * bw + az * bx
oz = aw * bz + ax * by - ay * bx + az * bw
return torch.stack((ow, ox, oy, oz), -1)
def _quat_rotation(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
"""
Rotates p by quaternion q. Usual torch rules for broadcasting apply.
Args:
q: Quaternions as tensor of shape (..., 4), real part first.
p: Points as tensor of shape (..., 3)
Returns:
The rotated version of p, of shape (..., 3)
"""
aw, ax, ay, az = torch.unbind(q, -1)
bx, by, bz = torch.unbind(p, -1)
# fmt: off
ow = - ax * bx - ay * by - az * bz
ox = aw * bx + ay * bz - az * by
oy = aw * by - ax * bz + az * bx
oz = aw * bz + ax * by - ay * bx
# fmt: on
q_mul_pts = torch.stack((ow, ox, oy, oz), -1)
return _quat_mult(q_mul_pts, _quat_invert(q))[..., 1:]
def _quat_invert(q: torch.Tensor):
return q * torch.tensor([1, -1, -1, -1], device=q.device)
def _sqrt_subgradient(x: torch.Tensor) -> torch.Tensor:
# Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0.
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12):
# A low eps here is necessary for good stability!
with fp32_autocast_context(x_axis.device.type):

View File

@@ -6,17 +6,21 @@ from typing import Any, ClassVar, Protocol, TypeVar
import numpy as np
import torch
from esm.utils.structure.protein_structure import (
compute_affine_and_rmsd,
)
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
class Alignable(Protocol):
atom37_positions: np.ndarray
atom37_mask: np.ndarray
# Trick to detect whether an object is a dataclass
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
@property
def atom37_positions(self) -> np.ndarray: # type: ignore
pass
@property
def atom37_mask(self) -> np.ndarray: # type: ignore
pass
def __len__(self) -> int: ...

View File

@@ -0,0 +1,15 @@
import numpy as np
from esm.utils.structure.protein_structure import index_by_atom_name
class AtomIndexer:
def __init__(self, structure, property: str, dim: int):
self.structure = structure
self.property = property
self.dim = dim
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
return index_by_atom_name(
getattr(self.structure, self.property), atom_names, self.dim
)

View File

@@ -1,19 +1,148 @@
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.amp import autocast # type: ignore
from esm.utils import residue_constants
from esm.utils.misc import unbinpack
from esm.utils.misc import binpack, unbinpack
from esm.utils.structure.protein_structure import (
compute_alignment_tensors,
compute_gdt_ts_no_alignment,
compute_rmsd_no_alignment,
)
def contact_precision(
predictions: Tensor,
targets: Tensor,
src_lengths: Tensor | None = None,
minsep: int = 6,
maxsep: int | None = None,
override_length: int | None = None, # for casp
):
"""Computes contact precisions.
For protein contact prediction, precision is measured for the top (L/K) highest confidence predictions,
with L being the length of the protein sequence and K generally being equal to 1 or 5.
K = 5 measures the predictions of the very highest confidence contacts, while K = 1 is a more general measure
over all relatively high confidence predictions.
Since there are roughly ~L true contacts in a protein, this is a reasonable cutoff.
Args:
predictions (Tensor): Tensor of probabilities of size (B, L, L)
targets (Tensor): Tensor of true contacts of size (B, L, L)
src_lengths (Tensor, optional): Lengths of each sample in the batch, if using variable lengths.
If not provided, inferred from the size of the predictions.
minsep (int): Minimum separation distance to consider. We often want to measure contacts at a
certain range. Typical ranges are short [6, 12), medium [12, 24), and long [24, inf).
maxsep (int, optional): Used in conjunction with minsep to specify a contact range. If not provided uses
assumes no maximum range
override_length (int, optional): Used for casp evaluation where sometimes the "true" length is not
the same as the length of the input. Kept for posterity, we probably don't need this argument.
"""
if predictions.dim() == 2:
predictions = predictions.unsqueeze(0)
if targets.dim() == 2:
targets = targets.unsqueeze(0)
# Check sizes
if predictions.size() != targets.size():
raise ValueError(
f"Size mismatch. Received predictions of size {predictions.size()}, "
f"targets of size {targets.size()}"
)
device = predictions.device
batch_size, seqlen, _ = predictions.size()
# Step 1) Construct a mask of size [B, L, L] to mask invalid contacts
seqlen_range = torch.arange(seqlen, device=device)
sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
sep = sep.unsqueeze(0)
# Mask contacts that are closer than minsep
valid_mask = sep >= minsep
# Mask contacts where target is negative (padding or unknown)
valid_mask = valid_mask & (targets >= 0) # negative targets are invalid
# Mask contacts that are farther than maxsep, if provided
if maxsep is not None:
valid_mask &= sep < maxsep
if src_lengths is not None:
# If the lengths of the individual sequences are provided, mask positions
# that are farther than the end of the sequence.
valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
else:
src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long)
# Fill in the logit tensor with -inf for all invalid positions
predictions = predictions.masked_fill(~valid_mask, float("-inf"))
# Step 2) Select the top half of the prediction (should be symmetric)
x_ind, y_ind = np.triu_indices(seqlen, minsep)
predictions_upper = predictions[:, x_ind, y_ind]
targets_upper = targets[:, x_ind, y_ind]
# Step 3) Select the topk values in each batch where k = L (length of sequence)
topk = seqlen if override_length is None else max(seqlen, override_length)
# Indices are the indices into the predictions corresponding to the most confident predictions
indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
# topk_targets are the target values corresponding to the above indices
topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices]
if topk_targets.size(1) < topk:
# If there aren't enough targets, pad to the output.
topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])
# Step 4) Sum the accuracy at of the top-i predictions for i in 1, L
# topk_targets => 1/0 true vs. false contact, sorted by confidence of prediction
# cmumulative sum => Number of correct answers for the top-i predictions.
cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)
# Step 5) Find the gather indices. This should be P@(L / K) for varous values of K
# The values will differ for each batch.
gather_lengths = src_lengths.unsqueeze(1)
if override_length is not None:
gather_lengths = override_length * torch.ones_like(
gather_lengths, device=device
)
# This gets you (0.1 * L, 0.2 * L, 0.3 * L, etc.)
gather_indices = (
(torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths).type(
torch.long
)
- 1
).clamp_min(0)
# Step 6) Gather the results and divide by the number of guesses to get the precision.
binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
binned_cumulative_dist
)
# Select specific P@L/k. pl5 is index 1 b/c that corresponds to L * 0.2 in
# gather_indices above
pl5 = binned_precisions[:, 1]
# pl2 = binned_precisions[:, 4]
pl = binned_precisions[:, 9]
# AUC is the integral wrt K of P@L/K for K in range(1, L)
auc = binned_precisions.mean(-1)
return {"AUC": auc, "P@L": pl, "P@L5": pl5}
def compute_lddt(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
pairwise_all_atom_mask: torch.Tensor | None = None,
cutoff: float | torch.Tensor = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
sequence_id: torch.Tensor | None = None,
@@ -29,7 +158,8 @@ def compute_lddt(
all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
cutoff (float): Max distance to score lddt over.
pairwise_all_atom_mask (Tensor[float], [B x (L * Natoms x L * Natoms)], optional): Tensor of masks, indicating whether a pair of atoms should be considered in the LDDT calculation.
cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
per_residue (bool): Whether to return per-residue or full-protein lddt.
sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed!
@@ -40,7 +170,7 @@ def compute_lddt(
else:
Tensor[float], [(Nstates x) B]
"""
n = all_atom_mask.shape[-2]
all_atom_mask = all_atom_mask[..., None] # add a dimension for broadcasting
dmat_true = torch.sqrt(
eps
+ torch.sum(
@@ -57,22 +187,56 @@ def compute_lddt(
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff)
* all_atom_mask
* rearrange(all_atom_mask, "... a b -> ... b a")
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
mask = all_atom_mask * rearrange(all_atom_mask, "... a b -> ... b a")
if pairwise_all_atom_mask is not None:
mask = mask * pairwise_all_atom_mask
if sequence_id is not None:
# TODO(roshan): This will work for lddt_ca, but not for regular lddt
# TODO: This will work for lddt_ca, but not for regular lddt
# Problem is that regular lddt has natoms * nres scores, so would need to repeat this mask by natoms
# Leaving for now because it won't fail silently so should be ook.
seqid_mask = sequence_id[..., None] == sequence_id[..., None, :]
dists_to_score = dists_to_score * seqid_mask.type_as(dists_to_score)
mask = mask * seqid_mask.type_as(mask)
return compute_lddt_from_dmat(
dmat_pred, dmat_true, mask, cutoff=cutoff, eps=eps, per_residue=per_residue
)
def compute_lddt_from_dmat(
dmat_pred: torch.Tensor,
dmat_true: torch.Tensor,
pairwise_mask: torch.Tensor,
cutoff: float | torch.Tensor = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
):
"""
Compute LDDT from pre-computed distance matrices.
This is useful when you want to compute LDDT with multiple different masks or cutoffs, e.g. for different molecule types (protein, nucleic acid, etc.).
Args:
dmat_pred (Tensor[float], [B x L x L]): Predicted distance matrix
dmat_true (Tensor[float], [B x L x L]): True distance matrix
pairwise_mask (Tensor[float], [B x L x L]): Pairwise mask indicating which pairs of atoms to consider
cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
per_residue (bool): Whether to return per-residue or full-protein lddt.
Returns:
LDDT Tensor:
if per_residue:
Tensor[float], [B x L]
else:
Tensor[float], [B]
"""
n = dmat_true.size(-1)
dists_to_score = (
(dmat_true < cutoff)
* pairwise_mask
* (1.0 - torch.eye(n, device=dmat_true.device))
)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
@@ -84,7 +248,6 @@ def compute_lddt(
dims = (-1,) if per_residue else (-2, -1)
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
return score
@@ -114,6 +277,56 @@ def compute_lddt_ca(
)
# NOTE(roshan): no_grad required for stack_variable_length_tensors apparently... let's revisit if we want to backprop
@torch.no_grad()
@autocast("cuda", enabled=False)
def compute_rmsd(
mobile: torch.Tensor,
target: torch.Tensor,
atom_exists_mask: torch.Tensor | None = None,
sequence_id: torch.Tensor | None = None,
reduction: str = "batch",
):
"""
Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.
Args:
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
- atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N)
- sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
- reduction (str): One of "batch", "per_sample", "per_residue".
Returns:
If reduction == "batch":
(torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch
If reduction == "per_sample":
(torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each batch
If reduction == "per_residue":
(torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for residue in the batch
"""
(centered_mobile, _, centered_target, _, rotation_matrix, num_valid_atoms) = (
compute_alignment_tensors(
mobile=mobile,
target=target,
atom_exists_mask=atom_exists_mask,
sequence_id=sequence_id,
)
)
# Apply transformation to centered structure
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
# Compute rmsd for centered structures
rmsd = compute_rmsd_no_alignment(
rotated_mobile, centered_target, num_valid_atoms, reduction=reduction
)
if reduction == "per_residue" and sequence_id is not None:
rmsd = binpack(rmsd, sequence_id, pad_value=0)
return rmsd
def compute_gdt_ts(
mobile: torch.Tensor,
target: torch.Tensor,

View File

@@ -0,0 +1,469 @@
from __future__ import annotations
import functools
import io
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Union
import biotite.structure as bs
import biotite.structure.io.pdbx as pdbx
from esm.utils import residue_constants
# Define PathOrBuffer for the opensource version
PathOrBuffer = Union[str, os.PathLike, io.StringIO]
class NoProteinError(Exception):
pass
@dataclass
class Residue:
residue_number: int | None = None
insertion_code: str = ""
hetflag: bool = False
@dataclass
class MmcifHeader:
release_date: datetime | None = None
resolution: float | None = None
structure_method: str = "UNKNOWN"
class MmcifWrapper:
def __init__(self, id: str | None = None):
self.id: str = id or ""
self.raw: pdbx.CIFFile | None = None
self.structure: bs.AtomArray
self.header: MmcifHeader = MmcifHeader()
self.entities: dict[int, list[str]] = {}
self.chain_to_seqres: dict[str, str] = {}
self.seqres_to_structure: dict[str, dict[int, Residue]] = {}
@classmethod
def read(cls, path: PathOrBuffer, id: str | None = None) -> MmcifWrapper:
obj = cls(id=id)
obj._load(path)
return obj
def _load(self, path: PathOrBuffer, fileid: str | None = None):
"""Load mmCIF data from file."""
self.raw = pdbx.CIFFile.read(path)
self._parse_structure()
self._parse_header()
self._parse_entities()
self._parse_sequences()
def _parse_structure(self):
"""Parse the atomic structure from mmCIF."""
try:
structure = pdbx.get_structure(self.raw, model=1)
if structure is None or not isinstance(structure, bs.AtomArray):
raise NoProteinError("No structure found in mmCIF file")
if len(structure) == 0:
raise NoProteinError("Empty structure in mmCIF file")
self.structure = structure
except Exception as e:
raise ValueError(f"Failed to parse structure: {e}")
def _parse_header(self):
"""Parse header information from mmCIF."""
if not self.raw:
return
try:
# Get the first (and usually only) block
block = self.raw.block
# Parse release date
if "pdbx_database_status" in block:
status_cat = block["pdbx_database_status"]
if "recvd_initial_deposition_date" in status_cat:
date_str = status_cat["recvd_initial_deposition_date"].as_item()
if date_str and date_str != "?":
try:
self.header.release_date = datetime.strptime(
date_str, "%Y-%m-%d"
)
except ValueError:
pass
# Parse resolution
if "refine" in block:
refine_cat = block["refine"]
if "ls_d_res_high" in refine_cat:
res_str = refine_cat["ls_d_res_high"].as_item()
if res_str and res_str != "?":
try:
self.header.resolution = float(res_str)
except ValueError:
pass
# Parse structure method
if "exptl" in block:
exptl_cat = block["exptl"]
if "method" in exptl_cat:
method = exptl_cat["method"].as_item()
if method and method != "?":
self.header.structure_method = method.upper()
except Exception:
# If parsing fails, keep default values
pass
def _parse_entities(self):
"""Parse entity information and map to chains."""
if not self.raw:
return
try:
block = self.raw.block
# Parse entity information
if "entity" in block:
entity_cat = block["entity"]
entity_ids = entity_cat["id"].as_array(str)
entity_types = entity_cat["type"].as_array(str)
# Initialize entities dict with all entities (not just polymers)
for i, (entity_id, entity_type) in enumerate(
zip(entity_ids, entity_types)
):
self.entities[int(entity_id)] = []
# Map polymer chains to entities using entity_poly
if "entity_poly" in block:
poly_cat = block["entity_poly"]
entity_ids = poly_cat["entity_id"].as_array(str)
chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
for entity_id, chain_list in zip(entity_ids, chain_lists):
entity_id = int(entity_id)
# Chain list is comma-separated
chains = [c.strip() for c in chain_list.split(",") if c.strip()]
if entity_id in self.entities:
self.entities[entity_id] = chains
# Map non-polymer chains using struct_asym for entities not covered by entity_poly
if "struct_asym" in block:
asym_cat = block["struct_asym"]
asym_ids = asym_cat["id"].as_array(str)
entity_ids = asym_cat["entity_id"].as_array(str)
for asym_id, entity_id in zip(asym_ids, entity_ids):
entity_id = int(entity_id)
# Only add if entity exists but has no chains yet (non-polymer entities)
if entity_id in self.entities and not self.entities[entity_id]:
self.entities[entity_id].append(asym_id)
except Exception:
# If parsing fails, try to infer from structure
if (
self.structure
and hasattr(self.structure, "chain_id")
and self.structure.chain_id is not None
and hasattr(self.structure.chain_id, "__iter__")
):
chain_ids = list(set(self.structure.chain_id))
self.entities = {1: chain_ids}
def _parse_sequences(self):
"""Parse sequence information from mmCIF."""
if not self.raw:
return
block = self.raw.block
# Parse polymer sequences
if "entity_poly" in block:
poly_cat = block["entity_poly"]
entity_ids = poly_cat["entity_id"].as_array(str)
sequences = poly_cat["pdbx_seq_one_letter_code_can"].as_array(str)
chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
for entity_id, sequence, chain_list in zip(
entity_ids, sequences, chain_lists
):
# Clean up sequence (remove whitespace and newlines)
clean_seq = "".join(sequence.split())
chains = [c.strip() for c in chain_list.split(",") if c.strip()]
for chain_id in chains:
self.chain_to_seqres[chain_id] = clean_seq
# Parse sequence to structure mapping
if "pdbx_poly_seq_scheme" in block:
seq_cat = block["pdbx_poly_seq_scheme"]
asym_ids = seq_cat["asym_id"].as_array(str) # Internal chain IDs
seq_positions = seq_cat["seq_id"].as_array(str)
auth_seq_nums = seq_cat["auth_seq_num"].as_array(str)
ins_codes = (
seq_cat["pdb_ins_code"].as_array(str)
if "pdb_ins_code" in seq_cat
else [""] * len(asym_ids)
)
hetflags = (
seq_cat["hetflag"].as_array(str)
if "hetflag" in seq_cat
else ["N"] * len(asym_ids)
)
# Get author chain IDs if available
auth_chain_ids = (
seq_cat["pdb_strand_id"].as_array(str)
if "pdb_strand_id" in seq_cat
else asym_ids # Fallback to internal IDs
)
# Build mapping from internal chain ID to author chain ID
asym_to_auth_mapping = {}
for asym_id, auth_id in zip(asym_ids, auth_chain_ids):
asym_to_auth_mapping[asym_id] = auth_id
# Group by internal chain ID first, then map to author chain ID
chain_data = {}
for asym_id, seq_pos, auth_seq, ins_code, hetflag in zip(
asym_ids, seq_positions, auth_seq_nums, ins_codes, hetflags
):
if asym_id not in chain_data:
chain_data[asym_id] = {}
try:
seq_index = int(seq_pos) - 1 # Convert to 0-based indexing
res_num = int(auth_seq) if auth_seq != "?" else None
except ValueError:
continue
if res_num is not None:
# Convert mmCIF "." and "?" to empty string
clean_ins_code = "" if ins_code in [".", "?"] else ins_code
else:
clean_ins_code = ""
res_num = None
is_het = hetflag.upper() == "Y" # type: ignore
chain_data[asym_id][seq_index] = Residue(
residue_number=res_num,
insertion_code=clean_ins_code, # type: ignore
hetflag=is_het,
)
# Handle cases where multiple residues have the same auth_seq_num
# by adjusting residue numbers to be unique within each chain
for asym_id, residue_data in chain_data.items():
# Check if there are duplicate residue numbers in this chain
positions_with_same_num = {}
for seq_idx, res_at_pos in residue_data.items():
if res_at_pos.residue_number is not None:
res_num = res_at_pos.residue_number
if res_num not in positions_with_same_num:
positions_with_same_num[res_num] = []
positions_with_same_num[res_num].append(seq_idx)
# Fix duplicate residue numbers by making them sequential
for res_num, seq_indices in positions_with_same_num.items():
if len(seq_indices) > 1:
# Multiple residues have the same residue number
# Make them sequential starting from the original number
seq_indices.sort() # Ensure consistent ordering
for i, seq_idx in enumerate(seq_indices):
original_pos = residue_data[seq_idx]
new_pos = Residue(
residue_number=res_num + i,
insertion_code=original_pos.insertion_code,
hetflag=original_pos.hetflag,
)
residue_data[seq_idx] = new_pos
# Create ordered mappings using author chain IDs
for asym_id in chain_data:
auth_chain_id = asym_to_auth_mapping.get(asym_id, asym_id)
if auth_chain_id in self.chain_to_seqres:
seq_len = len(self.chain_to_seqres[auth_chain_id])
ordered_mapping = {}
for i in range(seq_len):
if i in chain_data[asym_id]:
ordered_mapping[i] = chain_data[asym_id][i]
else:
# Missing residue - no structure coordinates
ordered_mapping[i] = Residue(
residue_number=None, insertion_code="", hetflag=False
)
self.seqres_to_structure[auth_chain_id] = ordered_mapping
else:
# Handle case where auth_chain_id is not in chain_to_seqres
# This can happen if the chain is not a polymer or if there's a parsing issue
# Create a basic mapping based on the chain_data
if chain_data[asym_id]:
# Sort by sequence index to create ordered mapping
sorted_indices = sorted(chain_data[asym_id].keys())
ordered_mapping = {}
for i, seq_idx in enumerate(sorted_indices):
ordered_mapping[i] = chain_data[asym_id][seq_idx]
self.seqres_to_structure[auth_chain_id] = ordered_mapping
# Ensure all chains have complete mappings
for chain_id in self.chain_to_seqres:
if chain_id not in self.seqres_to_structure:
seq_len = len(self.chain_to_seqres[chain_id])
self.seqres_to_structure[chain_id] = {
i: Residue(residue_number=None, insertion_code="", hetflag=False)
for i in range(seq_len)
}
else:
# Fill in any missing indices
seq_len = len(self.chain_to_seqres[chain_id])
mapping = self.seqres_to_structure[chain_id]
for i in range(seq_len):
if i not in mapping:
mapping[i] = Residue(
residue_number=None, insertion_code="", hetflag=False
)
# Fallback: create basic mappings from structure for missing chains
if (
self.structure
and hasattr(self.structure, "chain_id")
and self.structure.chain_id is not None
and hasattr(self.structure.chain_id, "__iter__")
):
for chain_id in set(self.structure.chain_id):
if chain_id not in self.seqres_to_structure:
chain_structure = self.structure[
self.structure.chain_id == chain_id
]
if (
hasattr(chain_structure, "res_id")
and chain_structure.res_id is not None
and hasattr(chain_structure.res_id, "__iter__")
):
residue_ids = list(set(chain_structure.res_id))
residue_ids.sort()
self.seqres_to_structure[chain_id] = {
i: Residue(
residue_number=res_id, insertion_code="", hetflag=False
)
for i, res_id in enumerate(residue_ids)
}
def _parse_nonpoly_from_mmcif(self) -> dict[tuple, bs.AtomArray]:
"""Parse non-polymer coordinates from mmCIF block data."""
nonpoly_coords = {}
# Get non-polymer entities from the mmCIF block
assert self.raw is not None
block = self.raw.block
nonpoly_entities = set()
# Find non-polymer entities
if "entity" in block:
entity_cat = block["entity"]
entity_ids = entity_cat["id"].as_array(str)
entity_types = entity_cat["type"].as_array(str)
for entity_id, entity_type in zip(entity_ids, entity_types):
if entity_type.upper() in ["NON-POLYMER", "WATER", "BRANCHED"]:
nonpoly_entities.add(entity_id)
# Map entities to chains for non-polymers
entity_to_chains = {}
if "pdbx_entity_nonpoly" in block:
nonpoly_cat = block["pdbx_entity_nonpoly"]
entity_ids = nonpoly_cat["entity_id"].as_array(str)
comp_ids = nonpoly_cat["comp_id"].as_array(str)
for entity_id, comp_id in zip(entity_ids, comp_ids):
if entity_id in nonpoly_entities:
entity_to_chains[entity_id] = comp_id
# Get atom site information for non-polymers
if "atom_site" in block:
atom_cat = block["atom_site"]
atom_chain_ids = atom_cat["label_asym_id"].as_array(str)
atom_entity_ids = atom_cat["label_entity_id"].as_array(str)
atom_comp_ids = atom_cat["label_comp_id"].as_array(str)
# Group non-polymer atoms by entity and chain
nonpoly_atom_groups = {}
for i, (chain_id, entity_id, comp_id) in enumerate(
zip(atom_chain_ids, atom_entity_ids, atom_comp_ids)
):
if entity_id in nonpoly_entities:
key = (comp_id, chain_id)
if key not in nonpoly_atom_groups:
nonpoly_atom_groups[key] = []
nonpoly_atom_groups[key].append(i)
# Extract coordinates for each non-polymer group
for (comp_id, chain_id), atom_indices in nonpoly_atom_groups.items():
# Match atoms by comparing chain_id and residue name
structure_mask = (self.structure.chain_id == chain_id) & (
self.structure.res_name == comp_id
)
if structure_mask.any():
nonpoly_array = self.structure[structure_mask]
if (
isinstance(nonpoly_array, (bs.AtomArray, bs.AtomArrayStack))
and len(nonpoly_array) > 0
):
nonpoly_coords[(comp_id, chain_id)] = nonpoly_array
return nonpoly_coords
def _parse_nonpoly_fallback(self) -> dict[tuple, bs.AtomArray]:
"""Fallback method to extract heteroatoms directly from structure."""
nonpoly_coords = {}
if not (self.structure and hasattr(self.structure, "chain_id")):
return nonpoly_coords
# Create set of standard residues from residue_constants
standard_residues = set(residue_constants.resnames[:-1]) # Exclude 'UNK'
standard_residues.update({"A", "C", "G", "T", "U"}) # Add nucleic acids
if hasattr(self.structure, "chain_id") and self.structure.chain_id is not None:
for chain_id in set(self.structure.chain_id):
chain_structure = self.structure[self.structure.chain_id == chain_id]
# Find non-standard residues
if (
hasattr(chain_structure, "res_name")
and chain_structure.res_name is not None
and hasattr(chain_structure.res_name, "__iter__")
):
for res_name in set(chain_structure.res_name):
if res_name not in standard_residues:
res_mask = (chain_structure.chain_id == chain_id) & (
chain_structure.res_name == res_name
)
if res_mask.any() and isinstance(
chain_structure, (bs.AtomArray, bs.AtomArrayStack)
):
nonpoly_array = chain_structure[res_mask]
nonpoly_coords[(res_name, chain_id)] = nonpoly_array
return nonpoly_coords
@functools.cached_property
def non_polymer_coords(self) -> dict[tuple, bs.AtomArray]:
"""
Extract non-polymer coordinates (ligands, cofactors, etc.) from mmCIF structure.
Returns a dictionary mapping (nonpolymer_info, chain_id) tuples to AtomArrays.
"""
if not self.structure or not self.raw:
return {}
try:
return self._parse_nonpoly_from_mmcif()
except Exception:
return self._parse_nonpoly_fallback()

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
import io
import warnings
from dataclasses import asdict, dataclass, replace
from functools import cached_property
from pathlib import Path
from typing import Sequence, TypeVar, Union
from typing import Any, Mapping, Sequence
import biotite.structure as bs
import brotli
@@ -12,51 +13,36 @@ import msgpack
import msgpack_numpy
import numpy as np
import torch
from Bio.Data import PDBData
from biotite.application.dssp import DsspApp
from biotite.database import rcsb
from biotite.structure.io.pdb import PDBFile
from cloudpathlib import CloudPath
from scipy.spatial import ConvexHull
from scipy.spatial.distance import pdist, squareform
from torch import Tensor
from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
from scipy.spatial import ConvexHull, KDTree
from scipy.spatial.distance import cdist, pdist, squareform
from esm.utils import residue_constants as RC
from esm.utils.constants import esm3 as C
from esm.utils import residue_constants
from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.metrics import compute_lddt_ca
from esm.utils.structure.atom_indexer import AtomIndexer
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
from esm.utils.structure.normalize_coordinates import (
apply_frame_to_coords,
get_protein_normalization_frame,
normalize_coordinates,
)
from esm.utils.structure.protein_structure import index_by_atom_name
from esm.utils.types import PathOrBuffer
msgpack_numpy.patch()
CHAIN_ID_CONST = "A"
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
PathLike = Union[str, Path, CloudPath]
PathOrBuffer = Union[PathLike, io.StringIO]
def index_by_atom_name(
atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
squeeze = False
if isinstance(atom_names, str):
atom_names = [atom_names]
squeeze = True
indices = [RC.atom_order[atom_name] for atom_name in atom_names]
dim = dim % atom37.ndim
index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
result = atom37[index] # type: ignore
if squeeze:
result = result.squeeze(dim)
return result
def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
return sum(
residue.residue_number is not None
for residue in seqres_to_structure_chain.values()
)
def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143):
@@ -78,34 +64,84 @@ def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143):
return Ca + sum([m * d for m, d in zip(m, d)])
class AtomIndexer:
def __init__(self, structure: ProteinChain, property: str, dim: int):
self.structure = structure
self.property = property
self.dim = dim
def chain_to_ndarray(
atom_array: bs.AtomArray, mmcif: MmcifWrapper, chain_id: str, is_predicted=False
):
entity_id = None
for entity, chains in mmcif.entities.items():
if chain_id in chains:
entity_id = entity
num_res = len(mmcif.chain_to_seqres[chain_id])
sequence = mmcif.chain_to_seqres[chain_id]
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
return index_by_atom_name(
getattr(self.structure, self.property), atom_names, self.dim
)
atom_positions = np.full([num_res, residue_constants.atom_type_num, 3], np.nan)
atom_mask = np.full([num_res, residue_constants.atom_type_num], False, dtype=bool)
residue_index = np.full([num_res], -1, dtype=np.int64)
insertion_code = np.full([num_res], "", dtype="<U4")
confidence = np.ones([num_res], dtype=np.float32)
for res_index in range(num_res):
chain = atom_array[atom_array.chain_id == chain_id]
assert isinstance(chain, bs.AtomArray)
res_at_position = mmcif.seqres_to_structure[chain_id][res_index]
if res_at_position.residue_number is None:
continue
residue_index[res_index] = res_at_position.residue_number
insertion_code[res_index] = res_at_position.insertion_code
res = chain[
(chain.res_id == res_at_position.residue_number)
& (chain.ins_code == res_at_position.insertion_code)
& (chain.hetero == res_at_position.hetflag)
]
assert isinstance(res, bs.AtomArray)
# Atom level features
for atom in res:
atom_name = atom.atom_name
if atom_name == "SE" and atom.res_name == "MSE":
# Put the coords of the selenium atom in the sulphur column
atom_name = "SD"
if atom_name in residue_constants.atom_order:
atom_positions[res_index, residue_constants.atom_order[atom_name]] = (
atom.coord
)
atom_mask[res_index, residue_constants.atom_order[atom_name]] = True
if is_predicted and atom_name == "CA":
confidence[res_index] = atom.b_factor
assert all(sequence), "Some residue name was not specified correctly"
return (
sequence,
atom_positions,
atom_mask,
residue_index,
insertion_code,
confidence,
entity_id,
)
@dataclass
@dataclass(frozen=True)
class ProteinChain:
"""Dataclass with atom37 representation of a single protein chain."""
id: str
sequence: str
chain_id: str # author chain id
chain_id: str # author chain id - mutable
entity_id: int | None
residue_index: np.ndarray
insertion_code: np.ndarray
atom37_positions: np.ndarray
atom37_mask: np.ndarray
confidence: np.ndarray
mmcif: MmcifWrapper | None = None
def __post_init__(self):
self.atom37_mask = self.atom37_mask.astype(bool)
assert self.atom37_mask.dtype == bool, self.atom37_mask.dtype
assert self.atom37_positions.shape[0] == len(self.sequence), (
self.atom37_positions.shape,
len(self.sequence),
@@ -152,10 +188,10 @@ class ProteinChain:
chain_id="A" if self.chain_id is None else self.chain_id,
res_id=res_idx,
ins_code=ins_code,
res_name=RC.restype_1to3.get(res_name, "UNK"),
res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
hetero=False,
atom_name=RC.atom_types[i],
element=RC.atom_types[i][0],
atom_name=residue_constants.atom_types[i],
element=residue_constants.atom_types[i][0],
b_factor=conf,
)
atoms.append(atom)
@@ -182,18 +218,20 @@ class ProteinChain:
# hard coded to as we currently only support single chain structures
chain_id=CHAIN_ID_CONST,
res_id=res_idx + 1,
res_name=RC.restype_1to3.get(res_name, "UNK"),
res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
hetero=False,
atom_name=RC.atom_types[i],
element=RC.atom_types[i][0],
atom_name=residue_constants.atom_types[i],
element=residue_constants.atom_types[i][0],
b_factor=conf,
)
atoms.append(atom)
return bs.array(atoms)
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
def __getitem__(self, idx: int | list[int] | slice | np.ndarray | torch.Tensor):
if isinstance(idx, int):
idx = [idx]
if isinstance(idx, torch.Tensor):
idx = idx.cpu().numpy()
sequence = slice_python_object_as_numpy(self.sequence, idx)
return replace(
@@ -216,17 +254,6 @@ class ProteinChain:
np.fill_diagonal(contacts, -1)
return contacts
def to_structure_encoder_inputs(
self, should_normalize_coordinates: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
coords = torch.tensor(self.atom37_positions, dtype=torch.float32)
plddt = torch.tensor(self.confidence, dtype=torch.float32)
residue_index = torch.tensor(self.residue_index, dtype=torch.long)
if should_normalize_coordinates:
coords = normalize_coordinates(coords)
return coords.unsqueeze(0), plddt.unsqueeze(0), residue_index.unsqueeze(0)
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
"""Dssp works better w/o insertions."""
f = PDBFile()
@@ -242,11 +269,77 @@ class ProteinChain:
buf.seek(0)
return buf.read()
def to_mmcif(self, path: PathOrBuffer):
f = CIFFile()
set_structure_pdbx(f, self.atom_array, data_block=self.id)
# incantations molstar needs to render pLDDT / confidence onto
# the structure with "alphafold-view"
f.block["ma_qa_metric"] = CIFCategory(
name="ma_qa_metric",
columns={
"id": CIFColumn(data=CIFData(array=np.array([1, 2]), dtype=np.int64)),
"mode": CIFColumn(
data=CIFData(array=np.array(["global", "local"]), dtype=np.str_)
),
"name": CIFColumn(
data=CIFData(array=np.array(["pLDDT", "pLDDT"]), dtype=np.str_)
),
},
)
# table is a duplicate of data already in the atom array, but
# needed by molstar to render pLDDT / confidence
resid_pldd_table = {
# hard coded to as we currently only support single chain structures
"label_asym_id": CIFColumn(
data=CIFData(
array=[CHAIN_ID_CONST] * len(self.residue_index), dtype=np.str_
)
),
"label_comp_id": CIFColumn(
data=CIFData(
array=[
residue_constants.restype_1to3.get(c, "UNK")
for c in self.sequence
],
dtype=np.str_,
)
),
"label_seq_id": CIFColumn(
data=CIFData(array=self.residue_index, dtype=np.int64)
),
"ordinal_id": CIFColumn(
data=CIFData(array=self.residue_index, dtype=np.int64)
),
# hard coded to show these are all local plDDT values
"metric_id": CIFColumn(
data=CIFData(array=["2"] * len(self.residue_index), dtype=np.str_)
),
"metric_value": CIFColumn(
data=CIFData(array=self.confidence, dtype=np.float32)
),
# hard coded to show there are the initial version, there are no revisions
"model_id": CIFColumn(
data=CIFData(array=["1"] * len(self.residue_index), dtype=np.str_)
),
}
f.block["ma_qa_metric_local"] = CIFCategory(
name="ma_qa_metric_local", columns=resid_pldd_table
)
f.write(path)
def to_mmcif_string(self) -> str:
buf = io.StringIO()
self.to_mmcif(buf)
buf.seek(0)
return buf.read()
def state_dict(self, backbone_only=False, json_serializable=False):
"""This state dict is optimized for storage, so it turns things to fp16 whenever
possible. Note that we also only support int32 residue indices, I'm hoping we don't
need more than 2**32 residues..."""
dct = {k: v for k, v in asdict(self).items()}
dct = {k: v for k, v in asdict(self).items() if k not in ["mmcif"]}
if backbone_only:
dct["atom37_mask"][:, 3:] = False
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
@@ -265,7 +358,11 @@ class ProteinChain:
return dct
def to_blob(self, backbone_only=False) -> bytes:
return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)))
return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)), quality=5)
@classmethod
def from_open_source(cls, pc: ProteinChain):
return cls(**vars(pc))
@classmethod
def from_state_dict(cls, dct):
@@ -280,11 +377,11 @@ class ProteinChain:
k: (v.astype(np.float32) if k in ["atom37_positions", "confidence"] else v)
for k, v in dct.items()
}
return cls(**dct)
return cls(**dct, mmcif=None)
@classmethod
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
"""NOTE: blob + sparse coding + brotli + fp16 reduces memory
"""NOTE(@zlin): blob + sparse coding + brotli + fp16 reduces memory
of chains from 52G/1M chains to 20G/1M chains, I think this is a good first
shot at compressing and dumping chains to disk. I'm sure there's better ways."""
match input:
@@ -296,24 +393,105 @@ class ProteinChain:
bytes = input
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
def dssp(self):
dssp = DsspApp.annotate_sse(self.atom_array_no_insertions)
full_dssp = np.full(len(self.sequence), "X", dtype="<U1")
full_dssp[self.atom37_mask.any(-1)] = dssp
return full_dssp
def sasa(self):
def sasa(self, by_residue: bool = True):
arr = self.atom_array_no_insertions
sasa_per_atom = bs.sasa(arr) # type: ignore
# Sum per-atom SASA into residue "bins", with np.bincount.
if by_residue:
# Sum per-atom SASA into residue "bins", with np.bincount.
assert arr.res_id is not None
# NOTE(rverkuil): arr.res_id is 1-indexed, but np.bincount returns a sum for bin 0, so we strip.
# NOTE(aderry): We compute only for residues with coordinates, return NaN otherwise.
num_trailing_residues = len(self) - arr.res_id.max()
sasa_per_residue = np.concatenate(
[
np.bincount(arr.res_id, weights=sasa_per_atom)[1:],
np.zeros(num_trailing_residues),
]
)
sasa_per_residue[~self.atom37_mask.any(-1)] = np.nan
assert len(sasa_per_residue) == len(self)
return sasa_per_residue
return sasa_per_atom
def sap_score(self, aggregation: str = "atom") -> np.ndarray:
"""Computes per-atom SAP score.
Can optionally aggregate by residue (by averaging over atoms. NOTE: this returns values only for residues that have coordinates!)
or full-protein (sum of SAP score for atoms with SAP > 0, as in Lauer et al. 2011)."""
sap_radius = 5.0
arr = self.atom_array_no_insertions
# asserts to avoid type errors
assert arr.res_id is not None
assert np.array_equal(
np.sort(np.unique(arr.res_id)), np.arange(1, arr.res_id.max() + 1)
), "SASA calculation expected contiguous res_ids in range(1, len(chain)+1)"
# NOTE: arr.res_id is 1-indexed, but np.bincount returns a sum for bin 0, so we strip.
sasa_per_residue = np.bincount(arr.res_id, weights=sasa_per_atom)[1:]
assert len(sasa_per_residue) == len(self)
return sasa_per_residue
assert arr.res_name is not None
assert arr.atom_name is not None
assert arr.coord is not None
# compute SASA and residue-specific properties
sasa_per_atom = self.sasa(by_residue=False)
resid_to_resname = dict(zip(arr.res_id, arr.res_name))
max_side_chain_asa = np.full(len(self), np.nan)
res_hydrophobicity = np.full(len(self), np.nan)
resolved_res_mask = self.atom37_mask.any(-1)
num_trailing_residues = len(self) - arr.res_id.max()
max_side_chain_asa[resolved_res_mask] = np.array(
[
residue_constants.side_chain_asa[resid_to_resname[i]]
for i in np.unique(arr.res_id)
]
)
res_hydrophobicity[resolved_res_mask] = np.array(
[
residue_constants.hydrophobicity[resid_to_resname[i]]
for i in np.unique(arr.res_id)
]
)
assert len(max_side_chain_asa) == len(self)
assert len(res_hydrophobicity) == len(self)
# compute SAP score
is_side_chain = ~bs.filter_peptide_backbone(arr)
sasa_per_atom[is_side_chain] = 0
kdtree = KDTree(arr.coord)
neighbors = kdtree.query_ball_tree(kdtree, sap_radius, p=2.0)
sap_by_atom = np.zeros_like(sasa_per_atom)
for i, nn_list in enumerate(neighbors):
saa_nn = np.zeros_like(sasa_per_atom)
saa_nn[nn_list] = sasa_per_atom[nn_list]
sasa_within_r = np.concatenate(
[
np.bincount(arr.res_id, weights=saa_nn)[1:],
np.zeros(num_trailing_residues),
]
)
sap = np.nansum((sasa_within_r / max_side_chain_asa) * res_hydrophobicity)
sap_by_atom[i] = sap
match aggregation:
case "atom":
return sap_by_atom
case "residue":
sap_by_residue = np.concatenate(
[
np.bincount(arr.res_id, weights=sap_by_atom)[1:],
np.zeros(num_trailing_residues),
]
) / (
np.concatenate(
[np.bincount(arr.res_id)[1:], np.zeros(num_trailing_residues)]
)
+ 1e-8
)
sap_by_residue[~resolved_res_mask] = np.nan
assert len(sap_by_residue) == len(self)
return sap_by_residue
case "protein":
return sum(sap_by_atom[sap_by_atom > 0]) # pyright: ignore[reportReturnType]
case _:
raise ValueError(
f"Invalid aggregation method: {aggregation}. Must be one of 'atom', 'residue', or 'protein'"
)
def globularity(self) -> float:
# Computes globularity using total volumes divided by MVEE.
@@ -325,10 +503,10 @@ class ProteinChain:
# NOTE(@zeming): due to the approximation we make here, that atoms never overlap, you might get >1 globularity
mask = self.atom37_mask.any(-1)
points = self.atom37_positions[self.atom37_mask]
sequence = [aa for aa, m in zip(self.sequence, mask) if m]
sequence = [aa for aa, m in zip(self.sequence, mask) if m] # type: ignore
A, _ = self._mvee(points, tol=1e-3)
mvee_volume = (4 * np.pi) / (3 * np.sqrt(np.linalg.det(A)))
volume = sum(RC.amino_acid_volumes[x] for x in sequence)
volume = sum(residue_constants.amino_acid_volumes[x] for x in sequence)
ratio = volume / mvee_volume
# The paper says you must compare the ellipsoidal profile with T, a measurement of
@@ -388,6 +566,10 @@ class ProteinChain:
return A, c
def radius_of_gyration(self):
arr = self.atom_array_no_insertions
return bs.gyration_radius(arr)
def align(
self,
target: ProteinChain,
@@ -462,8 +644,8 @@ class ProteinChain:
target_inds: list[int] | np.ndarray | None = None,
**kwargs,
) -> float | np.ndarray:
"""Compute the LDDT between this protein chain and another.
NOTE: LDDT IS NOT SYMMETRIC. The call should always be prediction.lddt_ca(native).
"""Compute the LDDT between this protein chain and another. NOTE: LDDT IS NOT SYMMETRIC.
The call should always be prediction.lddt_ca(native).
Arguments:
native (ProteinChain): The ground truth protein chain
@@ -482,6 +664,171 @@ class ProteinChain:
)
return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()
def gdt_ts(
self,
target: ProteinChain,
mobile_inds: list[int] | np.ndarray | None = None,
target_inds: list[int] | np.ndarray | None = None,
**kwargs,
) -> float | np.ndarray:
"""Compute the GDT_TS between this protein chain and another.
Arguments:
target (ProteinChain): The other protein chain to compare to.
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
Returns:
float: The GDT_TS score between the two protein chains.
"""
gdt_ts = compute_gdt_ts(
mobile=torch.tensor(
index_by_atom_name(self.atom37_positions[mobile_inds], "CA"),
dtype=torch.float32,
).unsqueeze(0),
target=torch.tensor(
index_by_atom_name(target.atom37_positions[target_inds], "CA"),
dtype=torch.float32,
).unsqueeze(0),
atom_exists_mask=torch.tensor(
index_by_atom_name(self.atom37_mask[mobile_inds], "CA", dim=-1)
& index_by_atom_name(target.atom37_mask[target_inds], "CA", dim=-1)
).unsqueeze(0),
**kwargs,
)
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
@classmethod
def chain_iterable_from_mmcif(
cls,
path: PathOrBuffer | MmcifWrapper,
id: str | None = None,
is_predicted: bool = False,
keep_source: bool = False,
):
"""Return a list[ProteinChain] object from an mmcif file, a iterable list of all protein chain
from an mmcif file
"""
if isinstance(path, MmcifWrapper):
mmcif = path
else:
mmcif = MmcifWrapper.read(path, id)
for chain in bs.chain_iter(mmcif.structure):
chain = chain[bs.filter_amino_acids(chain) & ~chain.hetero]
if len(chain) == 0:
continue
chain_id = chain.chain_id[0]
entity_id = None
for entity, chains in mmcif.entities.items():
if chain_id in chains:
entity_id = entity
assert entity_id is not None
(
sequence,
atom_positions,
atom_mask,
residue_index,
insertion_code,
confidence,
_,
) = chain_to_ndarray(chain, mmcif, chain_id, is_predicted)
assert all(sequence), "Some residue name was not specified correctly"
yield cls(
id=mmcif.id,
sequence=sequence,
chain_id=chain_id,
entity_id=entity_id,
atom37_positions=atom_positions,
atom37_mask=atom_mask,
residue_index=residue_index,
insertion_code=insertion_code,
confidence=confidence,
mmcif=mmcif if keep_source else None,
)
@classmethod
def from_mmcif(
cls,
path: PathOrBuffer | MmcifWrapper,
chain_id: str | None = None,
entity_id: int | None = None,
id: str | None = None,
is_predicted: bool = False,
keep_source: bool = False,
):
"""Return a ProteinChain object from an mmcif file.
Args:
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
chain_id (str, optional): Select a chain corresponding to (author) chain id.
entity_id (int, optional): Select a chain corresponding to a particular entity.
If neither `chain_id` nor `entity_id` is specified, defaults to the first entity.
"""
if isinstance(path, MmcifWrapper):
mmcif = path
else:
mmcif = MmcifWrapper.read(path, id)
# If neither chain_id nor entity_id is specified, default to the first entity
if chain_id is None and entity_id is None:
if not mmcif.entities:
raise ValueError("Structure contains no entities")
entity_id = min(mmcif.entities.keys()) # Pick the first entity by ID
if entity_id is not None:
assert chain_id is None
if entity_id not in mmcif.entities:
raise ValueError(
f"Structure does not contain entity `{entity_id}`. Valid entities: {mmcif.entities.keys()}"
)
chains = mmcif.entities[entity_id]
# Select the chain id corresponding to the longest chain. If all are equal length, selects the first.
chain_id = max(
chains,
key=lambda chain: _num_non_null_residues(
mmcif.seqres_to_structure[chain]
),
)
else:
assert chain_id is not None
for entity, chains in mmcif.entities.items():
if chain_id in chains:
entity_id = entity
if entity_id is None:
warnings.warn(
"Failed to detect entity_id from mmcif file, it may be malformed."
)
atom_array = mmcif.structure
(
sequence,
atom_positions,
atom_mask,
residue_index,
insertion_code,
confidence,
_,
) = chain_to_ndarray(atom_array, mmcif, chain_id, is_predicted)
assert all(sequence), "Some residue name was not specified correctly"
return cls(
id=mmcif.id,
sequence=sequence,
chain_id=chain_id,
entity_id=entity_id,
atom37_positions=atom_positions,
atom37_mask=atom_mask.astype(bool),
residue_index=residue_index,
insertion_code=insertion_code,
confidence=confidence,
mmcif=mmcif if keep_source else None,
)
@classmethod
def from_atom37(
cls,
@@ -549,11 +896,11 @@ class ProteinChain:
return cls(
id=id,
sequence=sequence,
sequence=sequence, # type: ignore
chain_id=chain_id,
entity_id=entity_id,
atom37_positions=atom37_positions,
atom37_mask=atom_mask,
atom37_mask=atom_mask.astype(bool),
residue_index=residue_index,
insertion_code=insertion_code,
confidence=confidence,
@@ -604,10 +951,12 @@ class ProteinChain:
id: str | None = None,
is_predicted: bool = False,
) -> "ProteinChain":
"""Return a ProteinChain object from an pdb file.
"""Return a ProteinChain object from an pdb file. NOTE: prefer mmcif for rcsb PDB files.
This function is mostly to interface with old PDB files and predicted structures -
it will not fill out the entity id correctly
Args:
path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
chain_id (str, optional): Select a chain corresponding to (author) chain id. "detect" uses the
@@ -637,20 +986,17 @@ class ProteinChain:
entity_id = 1 # Not supplied in PDBfiles
sequence = "".join(
(
r
if len(r := PDBData.protein_letters_3to1.get(monomer[0].res_name, "X"))
== 1
else "X"
)
residue_constants.restype_3to1.get(monomer[0].res_name, "X")
for monomer in bs.residue_iter(atom_array)
)
num_res = len(sequence)
atom_positions = np.full(
[num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32
[num_res, residue_constants.atom_type_num, 3], np.nan, dtype=np.float32
)
atom_mask = np.full(
[num_res, residue_constants.atom_type_num], False, dtype=bool
)
atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool)
residue_index = np.full([num_res], -1, dtype=np.int64)
insertion_code = np.full([num_res], "", dtype="<U4")
@@ -671,9 +1017,11 @@ class ProteinChain:
# Put the coords of the selenium atom in the sulphur column
atom_name = "SD"
if atom_name in RC.atom_order:
atom_positions[i, RC.atom_order[atom_name]] = atom.coord
atom_mask[i, RC.atom_order[atom_name]] = True
if atom_name in residue_constants.atom_order:
atom_positions[i, residue_constants.atom_order[atom_name]] = (
atom.coord
)
atom_mask[i, residue_constants.atom_order[atom_name]] = True
if is_predicted and atom_name == "CA":
confidence[i] = atom.b_factor
@@ -685,21 +1033,49 @@ class ProteinChain:
chain_id=chain_id,
entity_id=entity_id,
atom37_positions=atom_positions,
atom37_mask=atom_mask,
atom37_mask=atom_mask.astype(bool),
residue_index=residue_index,
insertion_code=insertion_code,
confidence=confidence,
mmcif=None,
)
@classmethod
def from_rcsb(cls, pdb_id: str, chain_id: str = "detect"):
"""Fetch a protein chain from the RCSB PDB database."""
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
return cls.from_pdb(f, chain_id=chain_id, id=pdb_id)
def from_mds(cls, data: dict[str, Any]) -> "ProteinChain":
return cls(
id=data["id"],
chain_id=data["chain_id"],
entity_id=data["entity_id"],
sequence=data["sequence"],
residue_index=data["residue_index"],
insertion_code=np.asarray(data["insertion_code"]),
atom37_positions=data["atom37_positions"],
atom37_mask=data["atom37_mask"].astype(bool),
confidence=data["confidence"],
mmcif=None,
)
@classmethod
def from_rcsb(
cls,
pdb_id: str,
chain_id: str | None = None,
entity_id: int | None = None,
keep_source: bool = False,
) -> ProteinChain:
f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
return cls.from_mmcif(
f,
id=pdb_id,
chain_id=chain_id,
entity_id=entity_id,
keep_source=keep_source,
is_predicted=False,
)
@classmethod
def from_atomarray(
cls, atom_array: bs.AtomArray, id: str | None = None
cls, atom_array: bs.AtomArray, id: str | None = None, is_predicted: bool = False
) -> "ProteinChain":
"""A simple converter from bs.AtomArray -> ProteinChain.
Uses PDB file format as intermediate."""
@@ -711,7 +1087,7 @@ class ProteinChain:
buf = io.StringIO()
pdb_file.write(buf)
buf.seek(0)
return cls.from_pdb(buf, id=id)
return cls.from_pdb(buf, id=id, is_predicted=is_predicted)
def get_normalization_frame(self) -> Affine3D:
"""Given a set of coordinates, compute a single frame.
@@ -745,6 +1121,8 @@ class ProteinChain:
def infer_oxygen(self) -> ProteinChain:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
N = torch.roll(N, -3)
@@ -756,9 +1134,11 @@ class ProteinChain:
atom37_positions = self.atom37_positions.copy()
atom37_mask = self.atom37_mask.copy()
atom37_positions[:, RC.atom_order["O"]] = O.numpy()
atom37_mask[:, RC.atom_order["O"]] = ~np.isnan(
atom37_positions[:, RC.atom_order["O"]]
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
O_missing_indices
].numpy()
atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
).any(-1)
new_chain = replace(
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
@@ -781,7 +1161,7 @@ class ProteinChain:
infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
residues, even though that residue doesn't have one. Default off.
NOTE: The reason for having this switch in the first place
NOTE(rverkuil): The reason for having this switch in the first place
is that sometimes we want a (inferred) CB coordinate for every residue,
for example for making a pairwise distance matrix, or doing an RMSD
calculation between two designs for a given structural template, w/
@@ -792,11 +1172,13 @@ class ProteinChain:
inferred_cbeta_positions = self.inferred_cbeta
if not infer_cbeta_for_glycine:
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.NAN
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
atom37_positions[:, RC.atom_order["CB"]] = inferred_cbeta_positions
atom37_mask[:, RC.atom_order["CB"]] = ~np.isnan(
atom37_positions[:, RC.atom_order["CB"]]
atom37_positions[:, residue_constants.atom_order["CB"]] = (
inferred_cbeta_positions
)
atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
atom37_positions[:, residue_constants.atom_order["CB"]]
).any(-1)
new_chain = replace(
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
@@ -822,36 +1204,69 @@ class ProteinChain:
)
@classmethod
def concat(cls, chains: Sequence[ProteinChain]):
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
full_array = []
for array in arrays:
full_array.append(array)
full_array.append(sep)
full_array = full_array[:-1]
return np.concatenate(full_array, 0)
def concat(cls, chains: Sequence[ProteinChain], use_chainbreak: bool = True):
sep_tokens = {
"residue_index": np.array([-1]),
"insertion_code": np.array([""]),
"atom37_positions": np.full([1, 37, 3], np.nan),
"atom37_mask": np.zeros([1, 37]),
"atom37_positions": np.full([1, 37, 3], np.inf),
"atom37_mask": np.zeros([1, 37], dtype=bool),
"confidence": np.array([0]),
}
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
if use_chainbreak:
full_array = []
for array in arrays:
full_array.append(array)
full_array.append(sep)
full_array = full_array[:-1]
return np.concatenate(full_array, 0)
else:
return np.concatenate(arrays, 0)
array_args: dict[str, np.ndarray] = {
name: join_arrays([getattr(chain, name) for chain in chains], sep)
for name, sep in sep_tokens.items()
}
chain_break = residue_constants.CHAIN_BREAK_TOKEN if use_chainbreak else ""
return cls(
id=chains[0].id,
sequence=C.CHAIN_BREAK_STR.join(chain.sequence for chain in chains),
sequence=chain_break.join(chain.sequence for chain in chains),
chain_id="A",
entity_id=None,
mmcif=None,
**array_args,
)
def find_nonpolymer_contacts(self):
assert self.mmcif is not None
nonpolymer_and_chain_id_to_array = self.mmcif.non_polymer_coords
results = []
for (
nonpolymer,
_,
), nonpolymer_array in nonpolymer_and_chain_id_to_array.items():
assert nonpolymer_array.coord is not None
chain_coords = self.atom37_positions[self.atom37_mask]
distance = cdist(nonpolymer_array.coord, chain_coords)
is_contact = distance < 5
if not is_contact.any():
continue
contacting_atoms = np.where(is_contact.any(0))[0]
chain_index = np.where(self.atom37_mask)[0]
contacting_residues = np.unique(chain_index[contacting_atoms])
result = {
"ligand": nonpolymer.name,
"ligand_id": nonpolymer.comp_id,
"contacting_residues": contacting_residues.tolist(),
}
results.append(result)
return results
def select_residue_indices(
self, indices: list[int | str], ignore_x_mismatch: bool = False
) -> ProteinChain:
@@ -876,3 +1291,25 @@ class ProteinChain:
raise RuntimeError(mismatch_str)
return new
def to_structure_encoder_inputs(
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert protein chain to structure encoder inputs.
Returns:
tuple: (coordinates, plddt, residue_index) where:
- coordinates: (1, L, 37, 3) tensor of atom positions
- plddt: (1, L) tensor of confidence scores
- residue_index: (1, L) tensor of residue indices
"""
# Convert to tensors and add batch dimension
coordinates = (
torch.from_numpy(self.atom37_positions).float().unsqueeze(0)
) # (1, L, 37, 3)
plddt = torch.from_numpy(self.confidence).float().unsqueeze(0) # (1, L)
residue_index = (
torch.from_numpy(self.residue_index).long().unsqueeze(0)
) # (1, L)
return coordinates, plddt, residue_index

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import io
import itertools
import random
import re
import warnings
from dataclasses import asdict, dataclass, replace
@@ -18,24 +19,28 @@ import msgpack_numpy
import numpy as np
import torch
from biotite.database import rcsb
from biotite.file import InvalidFileError
from biotite.structure.io.pdb import PDBFile
from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
from biotite.structure.io.pdbx.convert import _get_transformations, get_structure
from biotite.structure.util import matrix_rotate
from scipy.spatial import KDTree
from esm.utils import residue_constants
from esm.utils.constants import esm3 as esm3_c
from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.metrics import (
compute_gdt_ts,
compute_lddt_ca,
)
from esm.utils.structure.atom_indexer import AtomIndexer
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
from esm.utils.structure.protein_chain import (
PathOrBuffer,
ProteinChain,
)
from esm.utils.structure.protein_structure import (
chain_to_ndarray,
index_by_atom_name,
infer_CB,
)
from esm.utils.types import PathOrBuffer
msgpack_numpy.patch()
@@ -44,35 +49,72 @@ SINGLE_LETTER_CHAIN_IDS = (
)
def protein_chain_to_protein_complex(chain: ProteinChain) -> ProteinComplex:
if "|" not in chain.sequence:
return ProteinComplex.from_chains([chain])
chain_breaks = np.array(list(chain.sequence)) == "|"
chain_break_inds = np.where(chain_breaks)[0]
chain_break_inds = np.concatenate([[0], chain_break_inds, [len(chain)]])
chain_break_inds = np.array(list(zip(chain_break_inds[:-1], chain_break_inds[1:])))
complex_chains = []
for start, end in chain_break_inds:
if start != 0:
start += 1
complex_chains.append(chain[start:end])
complex_chains = [
ProteinChain.from_atom37(
chain.atom37_positions,
sequence=chain.sequence,
chain_id=SINGLE_LETTER_CHAIN_IDS[i],
entity_id=i,
)
for i, chain in enumerate(complex_chains)
]
return ProteinComplex.from_chains(complex_chains)
def _parse_operation_expression(expression):
"""
Get successive operation steps (IDs) for the given
``oper_expression``.
Form the cartesian product, if necessary.
Copied from biotite and fixed a bug
"""
# Split groups by parentheses:
# use the opening parenthesis as delimiter
# and just remove the closing parenthesis
expressions_per_step = expression.replace(")", "").split("(")
expressions_per_step = [e for e in expressions_per_step if len(e) > 0]
# Important: Operations are applied from right to left
expressions_per_step.reverse()
operations = []
for expr in expressions_per_step:
cur_expr = expr.split(",")
cur_op = []
# Deal with e='1-10,20-30,40-50' type expressions
for e in cur_expr:
if "-" in e:
first, last = e.split("-")
cur_op.extend(str(id) for id in range(int(first), int(last) + 1))
else:
cur_op.append(e)
operations.append(cur_op)
# Cartesian product of operations
return list(itertools.product(*operations))
def _apply_transformations_fast(chains, transformation_dict, operations):
"""
Get subassembly by applying the given operations to the input
structure containing affected asym IDs.
"""
# Additional first dimesion for 'structure.repeat()'
results = []
# Apply corresponding transformation for each copy in the assembly
for c in chains:
for operation in operations:
coord = c.atom37_positions.copy()
# Execute for each transformation step
# in the operation expression
for op_step in operation:
T = transformation_dict[op_step]
# Rotate
coord = matrix_rotate(coord, T.rotation)
# Translate
coord += T.target_translation
new_chain = replace(c, atom37_positions=coord)
results.append(new_chain)
return results
@dataclass
class ProteinComplexMetadata:
entity_lookup: dict[int, int]
chain_lookup: dict[int, str]
chain_boundaries: list[tuple[int, int]]
mmcif: MmcifWrapper | None = None
# This is a dictionary that maps assembly ids to the list of unique chains
# in that assembly. Allows for usage of `switch_assembly`.
assembly_composition: dict[str, list[str]] | None = None
@dataclass
@@ -100,19 +142,7 @@ class DockQResult:
aligned_rmsd: float
class AtomIndexer:
def __init__(self, structure: ProteinComplex, property: str, dim: int):
self.structure = structure
self.property = property
self.dim = dim
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
return index_by_atom_name(
getattr(self.structure, self.property), atom_names, self.dim
)
@dataclass
@dataclass(frozen=True)
class ProteinComplex:
"""Dataclass with atom37 representation of an entire protein complex."""
@@ -126,6 +156,7 @@ class ProteinComplex:
atom37_positions: np.ndarray
atom37_mask: np.ndarray
confidence: np.ndarray
# This metadata is parsed from the MMCIF file. For synthetic data, we do a best effort.
metadata: ProteinComplexMetadata
def __post_init__(self):
@@ -195,15 +226,59 @@ class ProteinComplex:
def __len__(self):
return len(self.sequence)
@property
def num_chains(self):
return len(self.chain_boundaries)
@cached_property
def atoms(self) -> AtomIndexer:
return AtomIndexer(self, property="atom37_positions", dim=-2)
@cached_property
def atom_mask(self) -> AtomIndexer:
return AtomIndexer(self, property="atom37_mask", dim=-1)
@cached_property
def chain_lengths(self) -> np.ndarray:
return np.diff(self.chain_boundaries, axis=1).flatten()
@cached_property
def chain_boundaries(self) -> list[tuple[int, int]]:
cb = [-1]
for i, s in enumerate(self.sequence):
if s == "|":
cb.append(i)
cb.append(len(self))
return [(cb[i] + 1, cb[i + 1]) for i in range(len(cb) - 1)]
def get_chain_by_index(self, index: int) -> ProteinChain:
try:
start, end = self.chain_boundaries[index]
return self[start:end].as_chain()
except IndexError:
raise IndexError(f"Chain index {index} out of bounds")
def get_chain_by_id(
self, chain_id: str, sample_chain_if_duplicate: bool = True
) -> ProteinChain:
valid_indices = [
index
for index, id_of_index in self.metadata.chain_lookup.items()
if id_of_index == chain_id
]
if not valid_indices:
raise KeyError(f"Chain ID {chain_id} not found")
if sample_chain_if_duplicate:
index_to_return = random.choice(valid_indices)
return self.get_chain_by_index(index_to_return)
else:
if len(valid_indices) > 1:
raise ValueError(f"Multiple chains with chain ID {chain_id} found")
return self.get_chain_by_index(valid_indices[0])
def chain_iter(self) -> Iterable[ProteinChain]:
boundaries = [i for i, s in enumerate(self.sequence) if s == "|"]
boundaries = [-1, *boundaries, len(self)]
for i in range(len(boundaries) - 1):
c = self.__getitem__(slice(boundaries[i] + 1, boundaries[i + 1]))
for start, end in self.chain_boundaries:
c = self[start:end]
yield c.as_chain()
def as_chain(self, force_conversion: bool = False) -> ProteinChain:
@@ -237,6 +312,7 @@ class ProteinComplex:
residue_index=self.residue_index,
insertion_code=self.insertion_code,
confidence=self.confidence,
mmcif=self.metadata.mmcif,
)
@classmethod
@@ -253,12 +329,6 @@ class ProteinComplex:
chains.append(ProteinChain.from_atomarray(chain, id))
return ProteinComplex.from_chains(chains)
@classmethod
def from_rcsb(cls, pdb_id: str):
"""Fetch a protein complex from the RCSB PDB database."""
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
return cls.from_pdb(f, id=pdb_id)
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
atom_array = None
for chain in self.chain_iter():
@@ -284,13 +354,29 @@ class ProteinComplex:
ids = SINGLE_LETTER_CHAIN_IDS
chains = []
for i, chain in enumerate(self.chain_iter()):
chain.chain_id = ids[i]
chain = replace(chain, chain_id=ids[i])
if i > len(ids):
raise RuntimeError("Too many chains to write to PDB file")
chains.append(chain)
return ProteinComplex.from_chains(chains)
def find_assembly_ids_with_chain(self, id: str) -> list[str]:
good_chains = []
if (comp := self.metadata.assembly_composition) is not None:
for assembly_id, chain_ids in comp.items():
if id in chain_ids:
good_chains.append(assembly_id)
else:
raise ValueError(
"Cannot switch assemblies on this ProteinComplex, you must create the assembly from mmcif to support this"
)
return good_chains
def switch_assembly(self, id: str):
assert self.metadata.mmcif is not None
return get_assembly_fast(self.metadata.mmcif, assembly_id=id)
def state_dict(self, backbone_only=False):
"""This state dict is optimized for storage, so it turns things to fp16 whenever
possible. Note that we also only support int32 residue indices, I'm hoping we don't
@@ -308,6 +394,11 @@ class ProteinComplex:
elif isinstance(v, ProteinComplexMetadata):
dct[k] = asdict(v)
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
dct["metadata"]["mmcif"] = None
# These can be populated with non-serializable objects and are not needed for reconstruction
dct.pop("atoms", None)
dct.pop("atom_mask", None)
dct.pop("per_chain_kd_trees", None)
return dct
def to_blob(self, backbone_only=False) -> bytes:
@@ -322,6 +413,10 @@ class ProteinComplex:
k: (v.astype(np.float32) if k in ["atom37_positions", "confidence"] else v)
for k, v in dct.items()
}
if "chain_boundaries" in dct:
del dct["chain_boundaries"]
if "chain_boundaries" in dct["metadata"]:
del dct["metadata"]["chain_boundaries"]
dct["metadata"] = ProteinComplexMetadata(**dct["metadata"])
return cls(**dct)
@@ -342,13 +437,45 @@ class ProteinComplex:
)
@classmethod
def from_chains(cls, chains: Sequence[ProteinChain]):
def from_rcsb(cls, pdb_id: str, keep_source: bool = False) -> ProteinComplex:
f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
return cls.from_mmcif(f, id=pdb_id, keep_source=keep_source, is_predicted=False)
@classmethod
def from_mmcif(
cls,
path: PathOrBuffer,
id: str | None = None,
assembly_id: str | None = None,
is_predicted: bool = False,
keep_source: bool = False,
):
"""Return a ProteinComplex object from an mmcif file.
TODO(@zeming): there's actually multiple complexes per file, but for ease of implementation,
we only consider the first defined complex!
Args:
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
chain_id (str, optional): Select a chain corresponding to (author) chain id.
"""
mmcif = MmcifWrapper.read(path, id)
return get_assembly_fast(mmcif, assembly_id=assembly_id)
@classmethod
def from_chains(
cls,
chains: Sequence[ProteinChain],
mmcif: MmcifWrapper | None = None,
all_assembly_metadata_dictionary: dict[str, list[str]] | None = None,
):
if not chains:
raise ValueError(
"Cannot create a ProteinComplex from an empty list of chains"
)
# TODO: Make a proper protein complex class
# TODO(roshan): Make a proper protein complex class
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
full_array = []
for array in arrays:
@@ -376,7 +503,6 @@ class ProteinComplex:
ent2num_max = -1
ent2num = {}
total_index = 0
chain_boundaries = []
for i, c in enumerate(chains):
num_res = c.residue_index.shape[0]
if c.chain_id not in chain2num:
@@ -401,7 +527,6 @@ class ProteinComplex:
}
)
chain_boundaries.append((total_index, total_index + num_res))
total_index += num_res + 1
sep = np.array([-1])
@@ -412,20 +537,25 @@ class ProteinComplex:
array_args.update(update)
metadata = ProteinComplexMetadata(
chain_boundaries=chain_boundaries,
mmcif=mmcif,
chain_lookup={v: k for k, v in chain2num.items()},
entity_lookup={v: k for k, v in ent2num.items()},
assembly_composition=all_assembly_metadata_dictionary,
)
return cls(
id=chains[0].id,
sequence=esm3_c.CHAIN_BREAK_STR.join(chain.sequence for chain in chains),
sequence=residue_constants.CHAIN_BREAK_TOKEN.join(
chain.sequence for chain in chains
),
metadata=metadata,
**array_args,
)
def infer_oxygen(self) -> ProteinComplex:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
N = torch.roll(N, -3)
@@ -437,15 +567,56 @@ class ProteinComplex:
atom37_positions = self.atom37_positions.copy()
atom37_mask = self.atom37_mask.copy()
atom37_positions[:, residue_constants.atom_order["O"]] = O.numpy()
atom37_mask[:, residue_constants.atom_order["O"]] = ~np.isnan(
atom37_positions[:, residue_constants.atom_order["O"]]
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
O_missing_indices
].numpy()
atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
).any(-1)
new_chain = replace(
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
)
return new_chain
def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinComplex:
"""Return a new chain with inferred CB atoms at all residues except GLY.
Args:
infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
residues, even though that residue doesn't have one. Default off.
NOTE(rverkuil): The reason for having this switch in the first place
is that sometimes we want a (inferred) CB coordinate for every residue,
for example for making a pairwise distance matrix, or doing an RMSD
calculation between two designs for a given structural template, w/
CB atoms.
"""
atom37_positions = self.atom37_positions.copy()
atom37_mask = self.atom37_mask.copy()
N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0)
# See usage in trDesign codebase.
# https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140
inferred_cbeta_positions = infer_CB(C, N, CA, 1.522, 1.927, -2.143)
if not infer_cbeta_for_glycine:
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
atom37_positions[:, residue_constants.atom_order["CB"]] = (
inferred_cbeta_positions
)
atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
atom37_positions[:, residue_constants.atom_order["CB"]]
).any(-1)
new_chain = replace(
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
)
return new_chain
@classmethod
def from_open_source(cls, pc: ProteinComplex):
# TODO(@zeming): deprecated, should delete
return pc
@classmethod
def concat(cls, objs: list[ProteinComplex]) -> ProteinComplex:
pdb_ids = [obj.id for obj in objs]
@@ -463,6 +634,50 @@ class ProteinComplex:
list(other.chain_iter())
), "Protein complexes must have the same number of chains"
def rmsd(
self,
target: ProteinComplex,
also_check_reflection: bool = False,
mobile_inds: list[int] | np.ndarray | None = None,
target_inds: list[int] | np.ndarray | None = None,
only_compute_backbone_rmsd: bool = False,
compute_chain_assignment: bool = True,
):
"""
Compute the RMSD between this protein chain and another.
Args:
target (ProteinComplex): The target (other) protein complex to compare to.
also_check_reflection (bool, optional): If True, also check if the reflection of the mobile atoms has a lower RMSD.
mobile_inds (list[int], optional): The indices of the mobile atoms to align. These are NOT residue indices
target_inds (list[int], optional): The indices of the target atoms to align. These are NOT residue indices
only_compute_backbone_rmsd (bool, optional): If True, only compute the RMSD of the backbone atoms.
"""
if compute_chain_assignment:
aligned = self.dockq(target).aligned
else:
aligned = self
aligner = Aligner(
aligned if mobile_inds is None else aligned[mobile_inds],
target if target_inds is None else target[target_inds],
only_compute_backbone_rmsd,
)
avg_rmsd = aligner.rmsd
if not also_check_reflection:
return avg_rmsd
aligner = Aligner(
aligned if mobile_inds is None else aligned[mobile_inds],
target if target_inds is None else target[target_inds],
only_compute_backbone_rmsd,
use_reflection=True,
)
avg_rmsd_neg = aligner.rmsd
return min(avg_rmsd, avg_rmsd_neg)
def lddt_ca(
self,
target: ProteinComplex,
@@ -537,6 +752,10 @@ class ProteinComplex:
# This function uses dockqv2 to compute the DockQ score. Because it does a mapping
# over all possible chains, it's quite slow. Be careful not to use this in an inference loop
# or something that requires fast scoring. It defaults to 8 CPUs.
#
# TODO(@zeming): Because we haven't properly implemented protein complexes for mmcif,
# if your protein has multi-letter or repeated chain IDs, this will fail. Please call
# pc = pc.normalize_chain_ids_for_pdb() before calling this function in that case (limit is 62 chains)
try:
pass
@@ -658,3 +877,316 @@ class ProteinComplex:
)
return result
@cached_property
def per_chain_kd_trees(self):
# Iterate over chains, build KDTree for each chain
kdtrees = []
CA = self.atoms["CA"]
for start, end in self.chain_boundaries:
chain_CA = CA[start:end]
chain_CA = chain_CA[np.isfinite(chain_CA).all(axis=-1)]
kdtrees.append(KDTree(chain_CA))
return kdtrees
def chain_adjacency(self, cutoff: float = 8.0) -> np.ndarray:
# Compute adjacency matrix for protein complex
num_chains = self.num_chains
adjacency = np.zeros((num_chains, num_chains), dtype=bool)
for (i, kdtree), (j, kdtree2) in itertools.combinations(
enumerate(self.per_chain_kd_trees), 2
):
adj = kdtree.query_ball_tree(kdtree2, cutoff)
any_is_adjacent = any(len(a) > 0 for a in adj)
adjacency[i, j] = any_is_adjacent
adjacency[j, i] = any_is_adjacent
return adjacency
def chain_adjacency_by_index(self, index: int, cutoff: float = 8.0) -> np.ndarray:
num_chains = len(self.chain_boundaries)
adjacency = np.zeros(num_chains, dtype=bool)
for i, kdtree in enumerate(self.per_chain_kd_trees):
if i == index:
continue
adj = kdtree.query_ball_tree(self.per_chain_kd_trees[index], cutoff)
adjacency[i] = any(len(a) > 0 for a in adj)
return adjacency
def add_prefix_to_chain_ids(self, prefix: str) -> ProteinComplex:
"""Rename all chains in the complex with a given prefix.
Args:
prefix (str): The prefix to use for the new chain IDs. Each chain will be
named as "{prefix}_{chain_id}".
Returns:
ProteinComplex: A new protein complex with renamed chains.
"""
new_chains = []
for chain in self.chain_iter():
# Create new chain with updated chain_id
new_chain = replace(chain, chain_id=f"{prefix}_{chain.chain_id}")
new_chains.append(new_chain)
return ProteinComplex.from_chains(new_chains)
def sasa(self, by_residue: bool = True):
chain = self.as_chain(force_conversion=True)
return chain.sasa(by_residue=by_residue)
def to_mmcif_string(self) -> str:
"""Convert the ProteinComplex to mmCIF format.
Returns:
str: The mmCIF content as a string.
"""
# Convert the ProteinComplex to a biotite AtomArray
# Collect all atoms from all chains
all_atoms = []
for chain in self.chain_iter():
chain_atom_array = chain.atom_array
# Convert AtomArray to list of atoms and add to collection
all_atoms.extend(chain_atom_array)
# Create combined AtomArray from all atoms
if not all_atoms:
raise ValueError("No atoms found in protein complex")
atom_array = bs.array(all_atoms)
# Create CIF file
f = CIFFile()
set_structure_pdbx(f, atom_array, data_block=self.id)
# Add entity information for proper mmCIF structure
self._add_entity_information(f)
# Write to string
output = io.StringIO()
f.write(output)
return output.getvalue()
def _add_entity_information(self, cif_file: CIFFile) -> None:
"""Add entity, entity_poly, and struct_asym sections to CIF file."""
# Group chains by sequence to create unique entities
entity_map = {} # sequence -> entity_id
chain_to_entity = {} # chain_id -> entity_id
entity_sequences = {} # entity_id -> sequence
entity_id_counter = 1
for chain in self.chain_iter():
sequence = chain.sequence
if sequence not in entity_map:
entity_map[sequence] = entity_id_counter
entity_sequences[entity_id_counter] = sequence
entity_id_counter += 1
chain_to_entity[chain.chain_id] = entity_map[sequence]
# Create _entity section
entity_ids = []
entity_types = []
entity_descriptions = []
for entity_id in sorted(entity_sequences.keys()):
entity_ids.append(str(entity_id))
entity_types.append("polymer")
entity_descriptions.append(f"Protein chain (entity {entity_id})")
cif_file.block["entity"] = CIFCategory(
name="entity",
columns={
"id": CIFColumn(
data=CIFData(array=np.array(entity_ids), dtype=np.str_)
),
"type": CIFColumn(
data=CIFData(array=np.array(entity_types), dtype=np.str_)
),
"pdbx_description": CIFColumn(
data=CIFData(array=np.array(entity_descriptions), dtype=np.str_)
),
},
)
# Create _entity_poly section
poly_entity_ids = []
poly_types = []
poly_nstd_linkages = []
poly_sequences = []
for entity_id in sorted(entity_sequences.keys()):
poly_entity_ids.append(str(entity_id))
poly_types.append("polypeptide(L)")
poly_nstd_linkages.append("no")
poly_sequences.append(entity_sequences[entity_id])
cif_file.block["entity_poly"] = CIFCategory(
name="entity_poly",
columns={
"entity_id": CIFColumn(
data=CIFData(array=np.array(poly_entity_ids), dtype=np.str_)
),
"type": CIFColumn(
data=CIFData(array=np.array(poly_types), dtype=np.str_)
),
"nstd_linkage": CIFColumn(
data=CIFData(array=np.array(poly_nstd_linkages), dtype=np.str_)
),
"pdbx_seq_one_letter_code": CIFColumn(
data=CIFData(array=np.array(poly_sequences), dtype=np.str_)
),
},
)
# Create _struct_asym section
asym_ids = []
asym_entity_ids = []
asym_details = []
for chain in self.chain_iter():
asym_ids.append(chain.chain_id)
asym_entity_ids.append(str(chain_to_entity[chain.chain_id]))
asym_details.append("")
cif_file.block["struct_asym"] = CIFCategory(
name="struct_asym",
columns={
"id": CIFColumn(data=CIFData(array=np.array(asym_ids), dtype=np.str_)),
"entity_id": CIFColumn(
data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_)
),
"details": CIFColumn(
data=CIFData(array=np.array(asym_details), dtype=np.str_)
),
},
)
def get_assembly_fast(
mmcif: MmcifWrapper,
assembly_id=None,
model=None,
data_block=None,
altloc="first",
use_author_fields=True,
):
pdbx_file = mmcif.raw
if pdbx_file is None:
raise InvalidFileError("No mmCIF data loaded")
assembly_gen_category = pdbx_file.block["pdbx_struct_assembly_gen"]
if assembly_gen_category is None:
raise InvalidFileError("File has no 'pdbx_struct_assembly_gen' category")
struct_oper_category = pdbx_file.block["pdbx_struct_oper_list"]
if struct_oper_category is None:
raise InvalidFileError("File has no 'pdbx_struct_oper_list' category")
if assembly_id is None:
assembly_id = assembly_gen_category["assembly_id"].data.array[0]
elif assembly_id not in assembly_gen_category["assembly_id"].data.array:
raise KeyError(f"File has no Assembly ID '{assembly_id}'")
### Calculate all possible transformations
transformations = _get_transformations(struct_oper_category)
### Get structure according to additional parameters
structure = get_structure(
pdbx_file, model, data_block, altloc, ["label_asym_id"], use_author_fields
)[0] # type: ignore
# TODO(@zeming) This line will remove all non-protein structural elements,
# we should remove this when we want to parse these too.
structure: bs.AtomArray = structure[
bs.filter_amino_acids(structure) & ~structure.hetero # type: ignore
]
if len(structure) == 0:
raise NoProteinError
unique_asym_ids = np.unique(structure.label_asym_id) # type: ignore
asym2chain = {}
asym2auth = {}
for asym_id in unique_asym_ids:
sub_structure: bs.AtomArray = structure[structure.label_asym_id == asym_id] # type: ignore
chain_id: str = sub_structure[0].chain_id # type: ignore
(
sequence,
atom_positions,
atom_mask,
residue_index,
insertion_code,
confidence,
entity_id,
) = chain_to_ndarray(sub_structure, mmcif, chain_id, False)
asym2chain[asym_id] = ProteinChain(
id=mmcif.id or "unknown",
sequence=sequence,
chain_id=chain_id,
entity_id=entity_id,
atom37_positions=atom_positions,
atom37_mask=atom_mask,
residue_index=residue_index,
insertion_code=insertion_code,
confidence=confidence,
mmcif=None,
)
asym2auth[asym_id] = chain_id
### Get transformations and apply them to the affected asym IDs
assembly = []
assembly_id_dict: dict[str, list[str]] = {}
# Process the target assembly ID
for aid, op_expr, asym_id_expr in zip(
assembly_gen_category["assembly_id"].data.array,
assembly_gen_category["oper_expression"].data.array,
assembly_gen_category["asym_id_list"].data.array,
):
if aid == assembly_id:
# Parse operations and asym IDs for this specific entry
operations = _parse_operation_expression(op_expr)
asym_ids = asym_id_expr.split(",")
# Filter affected asym IDs to only protein chains, preserving order
sub_structures = [
asym2chain[asym_id] for asym_id in asym_ids if asym_id in asym2chain
]
# Apply transformations
sub_assembly = _apply_transformations_fast(
sub_structures, transformations, operations
)
assembly.extend(sub_assembly)
# Build assembly_id_dict for this entry
assembly_id_dict[aid] = assembly_id_dict.get(aid, []) + [
asym2auth[id_] for id_ in asym_ids if id_ in asym2auth
]
if len(assembly) == 0:
raise NoProteinError
return ProteinComplex.from_chains(assembly, mmcif, assembly_id_dict)
def protein_chain_to_protein_complex(chain: ProteinChain) -> ProteinComplex:
if "|" not in chain.sequence:
return ProteinComplex.from_chains([chain])
chain_breaks = np.array(list(chain.sequence)) == "|"
chain_break_inds = np.where(chain_breaks)[0]
chain_break_inds = np.concatenate([[0], chain_break_inds, [len(chain)]])
chain_break_inds = np.array(list(zip(chain_break_inds[:-1], chain_break_inds[1:])))
complex_chains = []
for start, end in chain_break_inds:
if start != 0:
start += 1
complex_chains.append(chain[start:end])
complex_chains = [
ProteinChain.from_atom37(
chain.atom37_positions,
sequence=chain.sequence,
chain_id=SINGLE_LETTER_CHAIN_IDS[i],
entity_id=i,
)
for i, chain in enumerate(complex_chains)
]
return ProteinComplex.from_chains(complex_chains)

View File

@@ -194,7 +194,7 @@ def compute_rmsd_no_alignment(
mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
else:
mean_squared_error = diff.square().sum(dim=(1, 2)) / (
num_valid_atoms.squeeze(-1) * 3
num_valid_atoms.squeeze(-1)
)
rmsd = torch.sqrt(mean_squared_error)

View File

@@ -4,9 +4,7 @@ import pygtrie
from ipywidgets import widgets
from esm.sdk.api import FunctionAnnotation
from esm.tokenization.function_tokenizer import (
InterProQuantizedTokenizer,
)
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
TRIE: pygtrie.CharTrie | None = None

View File

@@ -7,15 +7,11 @@ import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from esm.sdk.api import ESMProtein
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.drawing.draw_function_annotations import (
draw_function_annotations,
)
from esm.widgets.utils.drawing.draw_protein_structure import (
draw_protein_structure,
)
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
from esm.widgets.utils.serialization import (
create_download_button_from_buffer,
protein_to_pdb_buffer,

View File

@@ -3,16 +3,9 @@ from typing import Any, Callable, Sequence
import ipywidgets as widgets
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_hex,
)
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.prompting import PromptManager

View File

@@ -4,16 +4,9 @@ import ipywidgets as widgets
import pydssp
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_hex,
)
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.prompting import PromptManager

View File

@@ -6,9 +6,7 @@ from esm.widgets.utils.drawing.colors import (
hex_to_rgba_tuple,
rgba_tuple_to_rgba_html_string,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.prompting import PromptManager

View File

@@ -10,12 +10,8 @@ from matplotlib.patches import Rectangle
from esm.utils.structure.protein_chain import ProteinChain
from esm.widgets.utils import indexing
from esm.widgets.utils.drawing.draw_protein_structure import (
draw_protein_structure,
)
from esm.widgets.utils.parsing import (
convert_range_string_to_list_of_ranges,
)
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.prompting import PromptManager

View File

@@ -9,10 +9,7 @@ from matplotlib import colormaps
from PIL import Image
from esm.sdk.api import FunctionAnnotation
from esm.utils.function.interpro import (
InterPro,
InterProEntryType,
)
from esm.utils.function.interpro import InterPro, InterProEntryType
@contextmanager

View File

@@ -9,9 +9,7 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
from esm.utils import encoding
from esm.widgets.utils import indexing
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
from esm.widgets.utils.drawing.draw_category_array import (
draw_data_array,
)
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
from esm.widgets.utils.printing import wrapped_print

View File

@@ -13,13 +13,9 @@ from esm.sdk.api import (
GenerationConfig,
)
from esm.utils.constants import models
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.serialization import (
create_download_results_button,
)
from esm.widgets.utils.serialization import create_download_results_button
def create_esm3_generation_launcher(

View File

@@ -1,8 +1,6 @@
from ipywidgets import widgets
from esm.widgets.components.sasa_prompt_selector import (
create_sasa_prompt_selector,
)
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
from esm.widgets.components.secondary_structure_prompt_selector import (
create_secondary_structure_prompt_selector,
)

View File

@@ -4,20 +4,12 @@ from ipywidgets import widgets
from esm.sdk.api import ESM3InferenceClient, ESMProtein
from esm.utils.constants import esm3 as C
from esm.widgets.components.function_annotator import (
create_function_annotator,
)
from esm.widgets.components.function_annotator import create_function_annotator
from esm.widgets.utils.prompting import PromptManagerCollection
from esm.widgets.utils.protein_import import ProteinImporter
from esm.widgets.views.esm3_generation_launcher import (
create_esm3_generation_launcher,
)
from esm.widgets.views.esm3_prompt_preview import (
create_esm3_prompt_preview,
)
from esm.widgets.views.esm3_prompt_selector import (
create_esm3_prompt_selector,
)
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
def create_generation_ui(

View File

@@ -6,9 +6,7 @@ from esm.sdk.api import (
ESMProteinError,
GenerationConfig,
)
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import ProteinImporter

View File

@@ -4,10 +4,7 @@ from textwrap import dedent
from ipywidgets import widgets
from esm.widgets.utils.clients import (
get_forge_client,
get_local_client,
)
from esm.widgets.utils.clients import get_forge_client, get_local_client
from esm.widgets.utils.types import ClientInitContainer

View File

@@ -6,9 +6,7 @@ from esm.sdk.api import (
ESMProteinError,
GenerationConfig,
)
from esm.widgets.components.results_visualizer import (
create_results_visualizer,
)
from esm.widgets.components.results_visualizer import create_results_visualizer
from esm.widgets.utils.printing import wrapped_print
from esm.widgets.utils.protein_import import ProteinImporter

5003
pixi.lock

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ name = "esm"
version = "3.2.1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.12,<3.13"
license = {file = "LICENSE.txt"}
authors = [
@@ -17,7 +17,7 @@ maintainers = [
classifiers = [
"Development Status :: 3 - Alpha",
"Topic :: Scientific/Engineering :: Bio-Informatics",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12",
]
dependencies = [
@@ -35,9 +35,25 @@ dependencies = [
"attrs",
"pandas",
"cloudpathlib",
"httpx",
"tenacity",
"zstd"
"zstd",
"ipywidgets",
"py3dmol",
"pydssp",
"boto3",
"pygtrie",
"dna_features_viewer",
"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
]
# Pytest
[tool.pytest.ini_options]
addopts = """
--cov=esm
--cov-report term-missing:skip-covered
-n auto
--ignore=tests/oss_pytests/test_oss_client.py
"""
[tool.setuptools]
package-dir = {"" = "."}
@@ -52,7 +68,7 @@ esm = ["data/*"]
[tool.pixi.project]
channels = ["conda-forge"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64"]
# These are build dependencies, to ensure pip support, keep run-time deps above in `dependencies`
[tool.pixi.dependencies]
@@ -60,6 +76,7 @@ pkg-config = "*"
cmake = "*"
pip = "*"
twine = "*"
python = "3.12.*"
[tool.pixi.pypi-dependencies]
esm = { path = ".", editable = true }
@@ -67,3 +84,61 @@ esm = { path = ".", editable = true }
[tool.pixi.tasks]
build-wheel = "python -m pip wheel --no-deps -w dist ."
upload-wheel = "python -m twine upload --repository pypi"
[tool.pixi.feature.dev.dependencies]
matplotlib = "*"
pre-commit = "*"
pytest = "*"
pytest-cov = "*"
pytest-xdist = "*"
seaborn = "*"
pyright = "==1.1.399"
[tool.pixi.feature.dev.tasks]
lint-all = "pre-commit run --all-files --show-diff-on-failure"
cov-test = "pytest -v --junitxml=pytest.xml --cov=esm"
[tool.pixi.environments]
default = {features = [], solve-group = "default"}
dev = {features = ["dev"], solve-group = "default"}
[tool.ruff]
extend-include = ["*.ipynb"]
[tool.ruff.lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`), sort imports ('I')
select = ["E4", "E7", "E9", "F", "I"]
ignore = [
# allow variable == False (tensors should do this)
"E712",
# allow assigning of lambdas
"E731",
# Allow ambiguous variables, e.g. we use O for oxygen
"E741",
# Ignore errors from jaxtyping hints
# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
"F722",
# TODO: Fix the few offenders in a follow up PR
"E721",
]
[tool.ruff.lint.isort]
split-on-trailing-comma = false
known-third-party = ["wandb"]
[tool.ruff.format]
skip-magic-trailing-comma = true
docstring-code-format = true
docstring-code-line-length = "dynamic"
[tool.isort]
known_third_party = ["wandb"]
[tool.pyright]
root = ['.']
useLibraryCodeForTypes = true
reportPrivateImportUsage = false
typeCheckingMode = "basic"
[tool.importlinter]
root_package = "esm"

15
tests/Makefile Normal file
View File

@@ -0,0 +1,15 @@
# OSS-specific variables and commands
DOCKER_TAG ?= dev
DOCKER_IMAGE_OSS=oss_pytests:${DOCKER_TAG}
build-oss-ci:
docker build -f oss_pytests/Dockerfile oss_pytests -t $(DOCKER_IMAGE_OSS)
start-docker-oss:
docker run \
--rm \
-e URL=${URL} \
-e ESM3_FORGE_TOKEN=${ESM3_FORGE_TOKEN} \
--name=$(USER)-oss_pytests \
--network=host \
${DOCKER_IMAGE_OSS}

View File

@@ -0,0 +1,19 @@
# Dockerfile.sdktest
FROM python:3.12-slim
# Install pip and basic dependencies
RUN apt-get update && apt-get install -y curl build-essential && rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /tests
# Copy requirements and install them (assumes esm-oss is one of them)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy in the tests
COPY . .
# Default command (can be overridden in docker run)
CMD ["pytest", "-v", "test_oss_client.py"]

View File

@@ -0,0 +1,3 @@
esm
pytest
httpx # TODO(williamxi): Remove this after the esm repo is fixed

View File

@@ -0,0 +1,85 @@
import os
import pytest
from esm.sdk import client # pyright: ignore
from esm.sdk.api import ( # pyright: ignore
ESMProtein,
ESMProteinTensor,
ForwardAndSampleOutput,
GenerationConfig,
LogitsConfig,
LogitsOutput,
SamplingConfig,
SamplingTrackConfig,
)
from esm.sdk.forge import SequenceStructureForgeInferenceClient # pyright: ignore
API_TOKEN = os.environ.get("ESM3_FORGE_TOKEN", "")
URL = os.environ.get("URL")
@pytest.mark.sdk
def test_oss_esm3_client():
assert URL is not None
sequence = "MALWMRLLPLLALLAL___PDPAAA"
model = "esm3-small-2024-03"
esm3_client = client(model=model, url=URL, token=API_TOKEN)
protein = ESMProtein(sequence)
encoded_protein = esm3_client.encode(input=protein)
assert isinstance(encoded_protein, ESMProteinTensor)
decoded_protein = esm3_client.decode(input=encoded_protein)
assert isinstance(decoded_protein, ESMProtein)
logits_config = LogitsConfig(sequence=True, return_embeddings=True)
result = esm3_client.logits(input=encoded_protein, config=logits_config)
assert isinstance(result, LogitsOutput)
sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1))
result = esm3_client.forward_and_sample(
input=encoded_protein, sampling_configuration=sampling_config
)
assert isinstance(result, ForwardAndSampleOutput)
generation_config = GenerationConfig(track="sequence", num_steps=4)
result = esm3_client.generate(input=protein, config=generation_config)
assert isinstance(result, ESMProtein)
@pytest.mark.sdk
def test_oss_esmc_client():
assert URL is not None
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
model = "esmc-300m-2024-12"
esmc_client = client(model=model, url=URL, token=API_TOKEN)
protein = ESMProtein(sequence)
encoded_protein = esmc_client.encode(input=protein)
assert isinstance(encoded_protein, ESMProteinTensor)
decoded_protein = esmc_client.decode(input=encoded_protein)
assert isinstance(decoded_protein, ESMProtein)
logits_config = LogitsConfig(
sequence=True, return_embeddings=True, return_hidden_states=True
)
result = esmc_client.logits(input=encoded_protein, config=logits_config)
assert isinstance(result, LogitsOutput)
@pytest.mark.sdk
def test_oss_sequence_structure_forge_inference_client():
assert URL is not None
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
model = "esm3-small-2024-03"
client = SequenceStructureForgeInferenceClient(
model=model, url=URL, token=API_TOKEN
)
encoded_protein = client.fold(sequence=sequence)
assert isinstance(encoded_protein, ESMProtein)

View File

@@ -0,0 +1,6 @@
import pytest
@pytest.mark.skip(reason="no other tests in this suite")
def test_placeholder():
pass