mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Sync over internal code to open source (#266)
Co-authored-by: Steve Chan <>
This commit is contained in:
23
.github/workflows/ci.yml
vendored
23
.github/workflows/ci.yml
vendored
@@ -18,7 +18,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test-precommit:
|
||||
runs-on: ubuntu-22.04-16core
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -32,20 +32,11 @@ jobs:
|
||||
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
|
||||
|
||||
- name: Check formatting and typing
|
||||
run: |
|
||||
set -e
|
||||
# pyright seems to do something weird at initialization that causes it to error out
|
||||
# We can ignore the first invocation here.
|
||||
pyright esm/__init__.py || true
|
||||
pre-commit install
|
||||
env NODE_OPTIONS="--max-old-space-size=16384" pre-commit run --all-files --show-diff-on-failure
|
||||
[ -z "$(git status --porcelain)" ] && true || (echo "❌❌❌ pre-commit hook failed! A few files changed ❌❌❌]" && git status --porcelain && false)
|
||||
git reset --hard HEAD # test without the pre-commit changes
|
||||
shell: pixi run bash -e {0}
|
||||
run: pixi run lint-all
|
||||
|
||||
|
||||
test-esm:
|
||||
runs-on: ubuntu-22.04-16core
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@@ -59,20 +50,18 @@ jobs:
|
||||
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
set -o pipefail
|
||||
pytest -v --junitxml=pytest.xml tests/ | tee pytest-coverage.txt
|
||||
shell: pixi run bash -e {0}
|
||||
run: pixi run cov-test
|
||||
|
||||
- name: Run Docker tests
|
||||
env:
|
||||
DOCKER_TAG: ${{ github.sha }}
|
||||
FORGE_URL: https://forge.evolutionaryscale.ai/
|
||||
ESM3_FORGE_TOKEN: ${{ secrets.ESM3_FORGE_TOKEN }}
|
||||
run: |
|
||||
set -e
|
||||
cd tests
|
||||
make build-oss-ci
|
||||
make start-docker-oss URL=${{ env.FORGE_URL }} DOCKER_TAG=${{ env.DOCKER_TAG }} ESM3_FORGE_TOKEN=${{ secrets.ESM3_FORGE_TOKEN }}
|
||||
make start-docker-oss URL=${{ env.FORGE_URL }} DOCKER_TAG=${{ env.DOCKER_TAG }} ESM3_FORGE_TOKEN=${{ env.ESM3_FORGE_TOKEN }}
|
||||
shell: pixi run bash -e {0}
|
||||
|
||||
- name: cleanup docker containers if they're hanging
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,3 +2,4 @@ esm.egg-info
|
||||
# pixi environments
|
||||
.pixi
|
||||
*.egg-info
|
||||
*.pyc
|
||||
|
||||
33
.pre-commit-config.yaml
Normal file
33
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
exclude: (fasta|pdb|cif|mds|json)$
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.2.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
exclude: pixi.lock
|
||||
- id: check-merge-conflict
|
||||
- repo: https://github.com/seddonym/import-linter
|
||||
rev: v1.12.1
|
||||
hooks:
|
||||
- id: import-linter
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.7.3
|
||||
hooks:
|
||||
- id: ruff # linter
|
||||
args: [ --fix ]
|
||||
- id: ruff-format # formatter
|
||||
types_or: [python, jupyter]
|
||||
- repo: https://github.com/RobertCraigie/pyright-python
|
||||
rev: v1.1.399
|
||||
hooks:
|
||||
- id: pyright
|
||||
name: pyright
|
||||
entry: pyright
|
||||
language: system
|
||||
types: [python]
|
||||
pass_filenames: true # For speed, we only check the files that are changed
|
||||
26
README.md
26
README.md
@@ -9,6 +9,7 @@
|
||||
|
||||
|
||||
- [Installation ](#installation-)
|
||||
- [Available Models](#available-models-)
|
||||
- [ESM 3](#esm-3-)
|
||||
- [Quickstart for ESM3 Open](#esm3-quickstart-)
|
||||
- [ESM3 98B via Forge API](#esm3-forge)
|
||||
@@ -33,6 +34,31 @@ To get started with ESM, install the python library using pip:
|
||||
pip install esm
|
||||
```
|
||||
|
||||
## Available Models <a name="available-models"></a>
|
||||
|
||||
### ESM 3 Family
|
||||
|
||||
| Model | Model Size | Release Date | Note |
|
||||
|-------|------------|--------------|------|
|
||||
| **Flagship Models** | | | Most users will be interested in using one of these models. |
|
||||
| esm3-large-2024-03 | 98B | 2024-03 | |
|
||||
| esm3-medium-2024-08 | 7B | 2024-08 | |
|
||||
| esm3-small-2024-08 | 1.4B | 2024-08 | |
|
||||
| **Published Models** | | | These models were used to generate all of the results in the ESM3 paper and are provided to facilitate reproducibility. |
|
||||
| esm3-large-2024-03 | 98B | 2024-03 | |
|
||||
| esm3-medium-2024-03 | 7B | 2024-03 | |
|
||||
| esm3-small-2024-03 | 1.4B | 2024-03 | |
|
||||
| **Experimental Models** | | | These models are provided for early use by researchers and are still under development. |
|
||||
| esm3-medium-multimer-2024-09 | 7B | 2024-09 | |
|
||||
|
||||
### ESM C Models
|
||||
|
||||
| Model | Model Size | Number of Layers | Release Date |
|
||||
|-------|------------|------------------|--------------|
|
||||
| esmc-6b-2024-12 | 6B | 80 | 2024-12 |
|
||||
| esmc-600m-2024-12 | 600M | 36 | 2024-12 |
|
||||
| esmc-300m-2024-12 | 300M | 30 | 2024-12 |
|
||||
|
||||
## ESM 3 <a name="esm3"></a>
|
||||
|
||||
[ESM3](https://www.evolutionaryscale.ai/papers/esm3-simulating-500-million-years-of-evolution-with-a-language-model) is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
"\n",
|
||||
"!pip install py3Dmol\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.models.esm3 import ESM3\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -13,9 +13,7 @@ from esm.tokenization import get_esm3_model_tokenizers
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from biotite.database import rcsb\n",
|
||||
"\n",
|
||||
"from esm.sdk.api import ESMProtein\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain\n",
|
||||
"from esm.utils.types import FunctionAnnotation\n",
|
||||
@@ -496,9 +497,10 @@
|
||||
"# Functions for visualizing InterPro function annotations\n",
|
||||
"\n",
|
||||
"from dna_features_viewer import GraphicFeature, GraphicRecord\n",
|
||||
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
|
||||
"from matplotlib import colormaps\n",
|
||||
"\n",
|
||||
"from esm.utils.function.interpro import InterPro, InterProEntryType\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def visualize_function_annotations(\n",
|
||||
" annotations: list[FunctionAnnotation],\n",
|
||||
|
||||
@@ -64,6 +64,7 @@
|
||||
"import matplotlib.pyplot as pl\n",
|
||||
"import py3Dmol\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
"\n",
|
||||
"!pip install py3Dmol\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.sdk import client\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.utils.structure.protein_chain import ProteinChain"
|
||||
|
||||
@@ -14,13 +14,13 @@
|
||||
"3. Minimize a biophysical energy function\n",
|
||||
"4. Use experimental screening data to guide designs with a regression model\n",
|
||||
"\n",
|
||||
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252).\n",
|
||||
"As long as your scoring function takes a protein as input and outputs a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding described in [Li, et al 2024](https://arxiv.org/abs/2408.08252) and constrained optimization using the Modified Differential Method of Multipliers from [Platt & Barr 1987](https://proceedings.neurips.cc/paper_files/paper/1987/file/a1126573153ad7e9f44ba80e99316482-Paper.pdf)\n",
|
||||
"\n",
|
||||
"In this notebook we will walk through a few examples to illustrate how to use guided generation. \n",
|
||||
"\n",
|
||||
"1. Guide towards high pTM for improved generation quality\n",
|
||||
"2. Generate a protein with no cysteine (C) residues\n",
|
||||
"3. Maximize protein globularity by minimizing the radius of gyration\n",
|
||||
"3. Maximize protein globularity by minimizing the radius of gyration, while keeping pTM high\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
@@ -49,6 +49,7 @@
|
||||
"source": [
|
||||
"import biotite.structure as bs\n",
|
||||
"import py3Dmol\n",
|
||||
"\n",
|
||||
"from esm.sdk.api import ESMProtein, GenerationConfig\n",
|
||||
"from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction"
|
||||
]
|
||||
@@ -269,6 +270,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Start from a fully masked protein\n",
|
||||
"PROTEIN_LENGTH = 256\n",
|
||||
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
|
||||
"\n",
|
||||
"# Call guided_generate\n",
|
||||
"no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(\n",
|
||||
" protein=starting_protein,\n",
|
||||
" num_decoding_steps=len(starting_protein) // 8,\n",
|
||||
@@ -302,7 +308,20 @@
|
||||
"source": [
|
||||
"## Maximize Globularity\n",
|
||||
"\n",
|
||||
"We use the radius of gyration as a proxy to maximize globularity, we also encourage generations to have high pTM"
|
||||
"We use the radius of gyration as a proxy to maximize globularity, and we will also encourage generations to have high pTM by using constraints"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from esm.sdk.experimental import (\n",
|
||||
" ConstraintType,\n",
|
||||
" ESM3GuidedDecodingWithConstraints,\n",
|
||||
" GenerationConstraint,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -313,12 +332,11 @@
|
||||
"source": [
|
||||
"class RadiousOfGyrationScoringFunction(GuidedDecodingScoringFunction):\n",
|
||||
" def __call__(self, protein: ESMProtein) -> float:\n",
|
||||
" # Use the negative radius of gyration as the score to maximize\n",
|
||||
" score = -1 * self.radius_of_gyration(protein)\n",
|
||||
"\n",
|
||||
" assert protein.ptm is not None, \"Protein must have pTM scores to be scored\"\n",
|
||||
" if protein.ptm < 0.5:\n",
|
||||
" # Penalize proteins with low pTM scores\n",
|
||||
" score = score * 2\n",
|
||||
" # Re-scale the score to be in a similar magnitude as pTM\n",
|
||||
" score = score / 100.0\n",
|
||||
"\n",
|
||||
" return score\n",
|
||||
"\n",
|
||||
@@ -335,8 +353,19 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"radius_guided_decoding = ESM3GuidedDecoding(\n",
|
||||
" client=model, scoring_function=RadiousOfGyrationScoringFunction()\n",
|
||||
"# Constrain generation to have pTM > 0.75\n",
|
||||
"ptm_constraint = GenerationConstraint(\n",
|
||||
" scoring_function=PTMScoringFunction(),\n",
|
||||
" constraint_type=ConstraintType.GREATER_EQUAL,\n",
|
||||
" value=0.75,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"radius_guided_decoding = ESM3GuidedDecodingWithConstraints(\n",
|
||||
" client=model,\n",
|
||||
" scoring_function=RadiousOfGyrationScoringFunction(),\n",
|
||||
" constraints=[ptm_constraint], # Add list of constraints\n",
|
||||
" damping=1.0, # Damping factor for the MMDM algorithm\n",
|
||||
" learning_rate=10.0, # Learning rate for the MMDM algorithm\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -346,6 +375,11 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Start from a fully masked protein\n",
|
||||
"PROTEIN_LENGTH = 256\n",
|
||||
"starting_protein = ESMProtein(sequence=\"_\" * PROTEIN_LENGTH)\n",
|
||||
"\n",
|
||||
"# Call guided_generate\n",
|
||||
"radius_guided_protein = radius_guided_decoding.guided_generate(\n",
|
||||
" protein=starting_protein,\n",
|
||||
" num_decoding_steps=len(starting_protein) // 8,\n",
|
||||
@@ -359,11 +393,32 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize the trajectory of the constrained generation\n",
|
||||
"radius_guided_decoding.visualize_latest_trajectory()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Visualize the generated protein\n",
|
||||
"view = py3Dmol.view(width=800, height=400)\n",
|
||||
"view.addModel(radius_guided_protein.to_pdb_string(), \"pdb\")\n",
|
||||
"view.setStyle({\"cartoon\": {\"color\": \"spectrum\"}})\n",
|
||||
"view.zoomTo()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check pTM\n",
|
||||
"radius_guided_protein.ptm"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
__version__ = "3.2.1"
|
||||
|
||||
|
||||
@@ -5,15 +5,12 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from esm.layers.rotary import (
|
||||
RotaryEmbedding,
|
||||
TritonRotaryEmbedding,
|
||||
)
|
||||
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func # type:ignore
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||
except (ImportError, RuntimeError):
|
||||
flash_attn_varlen_qkvpacked_func = None
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
@@ -117,7 +114,7 @@ class FlashMultiHeadAttention(MultiHeadAttention):
|
||||
)
|
||||
qkv_N3HD = self.rotary(qkv_N3HD, cu_seqlens, max_seqlen)
|
||||
|
||||
context_NHD = flash_attn_varlen_qkvpacked_func(
|
||||
context_NHD = flash_attn_varlen_qkvpacked_func( # type: ignore
|
||||
qkv_N3HD, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5
|
||||
)
|
||||
context_ND = einops.rearrange(context_NHD, "n h d -> n (h d)")
|
||||
|
||||
@@ -2,13 +2,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.layers.attention import (
|
||||
FlashMultiHeadAttention,
|
||||
MultiHeadAttention,
|
||||
)
|
||||
from esm.layers.geom_attention import (
|
||||
GeometricReasoningOriginalImpl,
|
||||
)
|
||||
from esm.layers.attention import FlashMultiHeadAttention, MultiHeadAttention
|
||||
from esm.layers.geom_attention import GeometricReasoningOriginalImpl
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from esm.utils.constants.physics import BB_COORDINATES
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
RotationMatrix,
|
||||
)
|
||||
from esm.utils.structure.affine3d import Affine3D, RotationMatrix
|
||||
|
||||
|
||||
class Dim6RotStructureHead(nn.Module):
|
||||
|
||||
@@ -13,10 +13,7 @@ from attr import dataclass
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
@@ -32,10 +29,7 @@ from esm.sdk.api import (
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||
from esm.utils.decoding import decode_protein_tensor
|
||||
from esm.utils.generation import (
|
||||
_batch_forward,
|
||||
@@ -50,9 +44,7 @@ from esm.utils.sampling import (
|
||||
get_default_sampling_config,
|
||||
validate_sampling_config,
|
||||
)
|
||||
from esm.utils.structure.affine3d import (
|
||||
build_affine3d_from_coordinates,
|
||||
)
|
||||
from esm.utils.structure.affine3d import build_affine3d_from_coordinates
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,9 +12,7 @@ from cloudpathlib import AnyPath
|
||||
|
||||
from esm.layers.regression_head import RegressionHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import merge_annotations, merge_ranges
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -7,10 +7,7 @@ from esm.layers.structure_proj import Dim6RotStructureHead
|
||||
from esm.layers.transformer_stack import TransformerStack
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import knn_graph
|
||||
from esm.utils.structure.affine3d import (
|
||||
Affine3D,
|
||||
build_affine3d_from_coordinates,
|
||||
)
|
||||
from esm.utils.structure.affine3d import Affine3D, build_affine3d_from_coordinates
|
||||
from esm.utils.structure.predicted_aligned_error import (
|
||||
compute_predicted_aligned_error,
|
||||
compute_tm,
|
||||
|
||||
@@ -6,14 +6,8 @@ import torch.nn as nn
|
||||
from esm.models.esm3 import ESM3
|
||||
from esm.models.esmc import ESMC
|
||||
from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import (
|
||||
StructureTokenDecoder,
|
||||
StructureTokenEncoder,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
get_esm3_model_tokenizers,
|
||||
get_esmc_model_tokenizers,
|
||||
)
|
||||
from esm.models.vqvae import StructureTokenDecoder, StructureTokenEncoder
|
||||
from esm.tokenization import get_esm3_model_tokenizers, get_esmc_model_tokenizers
|
||||
from esm.utils.constants.esm3 import data_root
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_FUNCTION_DECODER_V0,
|
||||
|
||||
@@ -2,27 +2,19 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import List, Sequence
|
||||
from typing import Sequence
|
||||
|
||||
import attr
|
||||
import torch
|
||||
from attr import asdict, define
|
||||
|
||||
import esm.utils.constants.api as C
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_esm3_model_tokenizers,
|
||||
)
|
||||
from esm.tokenization import TokenizerCollectionProtocol, get_esm3_model_tokenizers
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.misc import (
|
||||
get_chainbreak_boundaries_from_sequence,
|
||||
)
|
||||
from esm.utils.misc import get_chainbreak_boundaries_from_sequence
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.structure.protein_complex import (
|
||||
SINGLE_LETTER_CHAIN_IDS,
|
||||
ProteinComplex,
|
||||
)
|
||||
from esm.utils.structure.protein_complex import SINGLE_LETTER_CHAIN_IDS, ProteinComplex
|
||||
from esm.utils.types import FunctionAnnotation, PathOrBuffer
|
||||
|
||||
|
||||
@@ -43,7 +35,6 @@ class ESMProtein(ProteinType):
|
||||
plddt: torch.Tensor | None = None
|
||||
ptm: torch.Tensor | None = None
|
||||
|
||||
|
||||
# When calling EvolutionaryScale API, use this flag to disclose any
|
||||
# sequences that may potentially have concerns.
|
||||
# Such sequences may not go through standard safety filter for approved users.
|
||||
@@ -79,12 +70,9 @@ class ESMProtein(ProteinType):
|
||||
def from_protein_chain(
|
||||
cls, protein_chain: ProteinChain, with_annotations: bool = False
|
||||
) -> ESMProtein:
|
||||
# By default, we don't annotate with DSSP / SASA, which are expensive.
|
||||
# If mkdssp is installed, we can annotate with a flag.
|
||||
if with_annotations:
|
||||
return ESMProtein(
|
||||
sequence=protein_chain.sequence,
|
||||
secondary_structure=protein_chain.dssp().tolist(),
|
||||
sasa=protein_chain.sasa().tolist(),
|
||||
function_annotations=None,
|
||||
coordinates=torch.tensor(protein_chain.atom37_positions),
|
||||
@@ -123,7 +111,8 @@ class ESMProtein(ProteinType):
|
||||
protein_complex.to_pdb(pdb_path)
|
||||
|
||||
def to_pdb_string(self) -> str:
|
||||
protein_chain = self.to_protein_chain()
|
||||
# Note: This was modified to match .to_pdb() behavior. We can revisit this at some point
|
||||
protein_chain = self.to_protein_complex().infer_oxygen()
|
||||
return protein_chain.to_pdb_string()
|
||||
|
||||
def to_protein_chain(self) -> ProteinChain:
|
||||
@@ -172,6 +161,7 @@ class ESMProtein(ProteinType):
|
||||
if gt_chains is not None
|
||||
else SINGLE_LETTER_CHAIN_IDS[i],
|
||||
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
|
||||
confidence=self.plddt[start:end] if self.plddt is not None else None,
|
||||
)
|
||||
pred_chains.append(pred_chain)
|
||||
return ProteinComplex.from_chains(pred_chains)
|
||||
@@ -321,8 +311,6 @@ class InverseFoldingConfig:
|
||||
temperature: float = 1.0
|
||||
|
||||
|
||||
|
||||
|
||||
## Low Level Endpoint Types
|
||||
@define
|
||||
class SamplingTrackConfig:
|
||||
@@ -382,22 +370,23 @@ class LogitsConfig:
|
||||
# Embeddings.
|
||||
return_embeddings: bool = False
|
||||
return_hidden_states: bool = False
|
||||
return_mean_embedding: bool = False
|
||||
return_mean_hidden_states: bool = False
|
||||
ith_hidden_layer: int = -1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@define
|
||||
class LogitsOutput:
|
||||
logits: ForwardTrackData | None = None
|
||||
embeddings: torch.Tensor | None = None
|
||||
mean_embedding: torch.Tensor | None = None
|
||||
|
||||
# Residue annotations is multi-hot, so deserves special treatment
|
||||
# It's not a categorical distribution, but instead a bernoulli, so
|
||||
# softmax across the last dimension is _wrong_
|
||||
residue_annotation_logits: torch.Tensor | None = None
|
||||
hidden_states: torch.Tensor | None = None
|
||||
mean_hidden_state: torch.Tensor | None = None
|
||||
|
||||
|
||||
@define
|
||||
|
||||
@@ -70,11 +70,13 @@ class _BaseForgeInferenceClient:
|
||||
def prepare_request(
|
||||
self,
|
||||
request: dict[str, Any],
|
||||
potential_sequence_of_concern: bool = False,
|
||||
potential_sequence_of_concern: bool | None = None,
|
||||
return_bytes: bool = False,
|
||||
headers: dict[str, str] = {},
|
||||
) -> tuple[dict[str, Any], dict[str, str]]:
|
||||
request["potential_sequence_of_concern"] = potential_sequence_of_concern
|
||||
if potential_sequence_of_concern is not None:
|
||||
request["potential_sequence_of_concern"] = potential_sequence_of_concern
|
||||
|
||||
headers = {**self.headers, **headers}
|
||||
if return_bytes:
|
||||
headers["return-bytes"] = "true"
|
||||
@@ -103,42 +105,58 @@ class _BaseForgeInferenceClient:
|
||||
self,
|
||||
endpoint,
|
||||
request,
|
||||
potential_sequence_of_concern: bool = False,
|
||||
potential_sequence_of_concern: bool | None = None,
|
||||
params: dict[str, Any] = {},
|
||||
headers: dict[str, str] = {},
|
||||
return_bytes: bool = False,
|
||||
):
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
)
|
||||
response = await self.async_client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
json=request,
|
||||
params=params,
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
data = self.prepare_data(response, endpoint)
|
||||
return data
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
)
|
||||
response = await self.async_client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
json=request,
|
||||
params=params,
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
data = self.prepare_data(response, endpoint)
|
||||
return data
|
||||
except ESMProteinError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
|
||||
)
|
||||
|
||||
def _post(
|
||||
self,
|
||||
endpoint,
|
||||
request,
|
||||
potential_sequence_of_concern: bool = False,
|
||||
potential_sequence_of_concern: bool | None = None,
|
||||
params: dict[str, Any] = {},
|
||||
headers: dict[str, str] = {},
|
||||
return_bytes: bool = False,
|
||||
):
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
)
|
||||
response = self.client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
json=request,
|
||||
params=params,
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
data = self.prepare_data(response, endpoint)
|
||||
return data
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
)
|
||||
response = self.client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
json=request,
|
||||
params=params,
|
||||
headers=headers,
|
||||
timeout=self.request_timeout,
|
||||
)
|
||||
data = self.prepare_data(response, endpoint)
|
||||
return data
|
||||
except ESMProteinError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Failed to submit request to {endpoint}. Error: {e}",
|
||||
)
|
||||
|
||||
14
esm/sdk/experimental/__init__.py
Normal file
14
esm/sdk/experimental/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from .constrained_generation import (
|
||||
ConstraintType,
|
||||
ESM3GuidedDecodingWithConstraints,
|
||||
GenerationConstraint,
|
||||
)
|
||||
from .guided_generation import ESM3GuidedDecoding, GuidedDecodingScoringFunction
|
||||
|
||||
__all__ = [
|
||||
"ConstraintType",
|
||||
"ESM3GuidedDecodingWithConstraints",
|
||||
"GenerationConstraint",
|
||||
"ESM3GuidedDecoding",
|
||||
"GuidedDecodingScoringFunction",
|
||||
]
|
||||
324
esm/sdk/experimental/constrained_generation.py
Normal file
324
esm/sdk/experimental/constrained_generation.py
Normal file
@@ -0,0 +1,324 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.collections import LineCollection
|
||||
from matplotlib.colors import Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from esm.sdk import batch_executor
|
||||
from esm.sdk.api import (
|
||||
ESM3InferenceClient,
|
||||
ESMProtein,
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.experimental.guided_generation import (
|
||||
ESM3GuidedDecoding,
|
||||
GuidedDecodingScoringFunction,
|
||||
)
|
||||
|
||||
|
||||
class ConstraintType(Enum):
|
||||
GREATER_EQUAL = "greater_equal" # f(x) ≥ threshold
|
||||
LESS_EQUAL = "less_equal" # f(x) ≤ threshold
|
||||
EQUAL = "equal" # f(x) = threshold
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenerationConstraint:
|
||||
"""
|
||||
A single inequality or equality constraint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scoring_function
|
||||
Maps a protein ➜ real value (e.g. pTM, length, …).
|
||||
value
|
||||
Target value for inequality or equality constraint.
|
||||
constraint_type
|
||||
Type of constraint to apply.
|
||||
- GREATER_EQUAL: f(x) ≥ value
|
||||
- LESS_EQUAL: f(x) ≤ value
|
||||
- EQUAL: f(x) = value (equality)
|
||||
"""
|
||||
|
||||
scoring_function: GuidedDecodingScoringFunction
|
||||
value: float
|
||||
constraint_type: ConstraintType = ConstraintType.GREATER_EQUAL
|
||||
lambda_: float = field(default=0.0, init=False) # dual variable (MDMM)
|
||||
|
||||
def g(self, x: float) -> float:
|
||||
"""
|
||||
Canonical form:
|
||||
|
||||
• inequalities → g(x) ≤ 0
|
||||
• equalities → h(x) (we still return a scalar, no ≤)
|
||||
"""
|
||||
if self.constraint_type is ConstraintType.GREATER_EQUAL:
|
||||
return self.value - x
|
||||
if self.constraint_type is ConstraintType.LESS_EQUAL:
|
||||
return x - self.value
|
||||
# equality: h(x) = x - value (we will *not* project λ)
|
||||
return x - self.value
|
||||
|
||||
def update_lambda(self, g: float, eta: float, gamma: float) -> None:
|
||||
"""
|
||||
Update the dual variable λ according to the MDMM update rule.
|
||||
"""
|
||||
if self.constraint_type is ConstraintType.EQUAL:
|
||||
self.lambda_ += eta * g # no projection for equality constraints
|
||||
else:
|
||||
self.lambda_ = max(0.0, self.lambda_ + eta * g)
|
||||
|
||||
def copy(self) -> GenerationConstraint:
|
||||
"""
|
||||
Create a copy of this constraint.
|
||||
"""
|
||||
c = GenerationConstraint(
|
||||
scoring_function=self.scoring_function,
|
||||
value=self.value,
|
||||
constraint_type=self.constraint_type,
|
||||
)
|
||||
c.lambda_ = self.lambda_ # copy the dual variable
|
||||
return c
|
||||
|
||||
|
||||
class ESM3GuidedDecodingWithConstraints(ESM3GuidedDecoding):
|
||||
"""
|
||||
Derivative-free guided decoding with constraints.
|
||||
|
||||
Uses the Modified Differential Method of Multipliers (MDMM) to
|
||||
guarantee convergence to the constrained optimum without
|
||||
hand-tuning penalty weights.
|
||||
|
||||
References:
|
||||
[1] Platt, John, and Alan Barr. "Constrained differential optimization." Neural Information Processing Systems. 1987.
|
||||
[2] https://www.engraved.blog/how-we-can-make-machine-learning-algorithms-tunable/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ESM3InferenceClient,
|
||||
scoring_function: GuidedDecodingScoringFunction,
|
||||
constraints: GenerationConstraint | list[GenerationConstraint],
|
||||
*,
|
||||
damping: float = 10.0,
|
||||
learning_rate: float = 1.0,
|
||||
):
|
||||
super().__init__(client, scoring_function)
|
||||
|
||||
if isinstance(constraints, GenerationConstraint):
|
||||
constraints = [constraints]
|
||||
|
||||
self.constraints = [c.copy() for c in constraints]
|
||||
self.gamma = float(damping)
|
||||
self.eta = float(learning_rate)
|
||||
|
||||
self.recorder: TrajectoryRecorder | None = None
|
||||
|
||||
def guided_generate(
|
||||
self,
|
||||
protein: ESMProtein,
|
||||
num_decoding_steps: int,
|
||||
num_samples_per_step: int,
|
||||
denoised_prediction_temperature: float = 0.0,
|
||||
track: str = "sequence",
|
||||
verbose: bool = True,
|
||||
) -> ESMProtein:
|
||||
# Reset the trajectory recorder
|
||||
self.recorder = TrajectoryRecorder()
|
||||
|
||||
protein_tensor = self.client.encode(protein)
|
||||
assert not isinstance(protein_tensor, ESMProteinError)
|
||||
|
||||
if track == "structure":
|
||||
protein_tensor = self.maybe_add_default_structure_tokens(protein_tensor)
|
||||
|
||||
n_masked = self.get_number_of_masked_positions(protein_tensor, track=track)
|
||||
n_unmask = n_masked // num_decoding_steps
|
||||
|
||||
best_reward = float("-inf")
|
||||
|
||||
if verbose:
|
||||
pbar = tqdm(range(num_decoding_steps), desc="S=-inf λ=0.00")
|
||||
else:
|
||||
pbar = range(num_decoding_steps)
|
||||
|
||||
for step in pbar:
|
||||
# Last iteration: unmask whatever is left
|
||||
if step == num_decoding_steps - 1:
|
||||
n_unmask = self.get_number_of_masked_positions(
|
||||
protein_tensor, track=track
|
||||
)
|
||||
|
||||
# ---------- propose & evaluate M samples (parallel-safe) ---- #
|
||||
def _propose_and_eval(pt: ESMProteinTensor):
|
||||
new_pt = self.randomly_unmask_positions(pt, n_unmask, track=track)
|
||||
reward, g_val, raw_vals = self._score_and_constraints(
|
||||
new_pt, denoised_prediction_temperature
|
||||
)
|
||||
return new_pt, reward, g_val, raw_vals
|
||||
|
||||
if self._use_batch_executor:
|
||||
with batch_executor(show_progress=False) as ex:
|
||||
results = ex.execute_batch(
|
||||
user_func=_propose_and_eval,
|
||||
pt=[protein_tensor] * num_samples_per_step,
|
||||
)
|
||||
|
||||
if isinstance(results, Exception):
|
||||
raise results
|
||||
samples, rewards, gh_lists, val_lists = zip(*results) # type: ignore
|
||||
else:
|
||||
samples, rewards, gh_lists, val_lists = [], [], [], []
|
||||
for _ in range(num_samples_per_step):
|
||||
s, r, g, c = _propose_and_eval(protein_tensor)
|
||||
samples.append(s)
|
||||
rewards.append(r)
|
||||
gh_lists.append(g)
|
||||
val_lists.append(c)
|
||||
|
||||
# -------- compute MDMM lagrangian for each sample -----------
|
||||
lags = [self._lagrangian(r, ghs) for r, ghs in zip(rewards, gh_lists)]
|
||||
|
||||
best_idx = int(torch.tensor(lags).argmin())
|
||||
protein_tensor = samples[best_idx]
|
||||
best_reward = rewards[best_idx]
|
||||
best_g_vals = gh_lists[best_idx]
|
||||
|
||||
# -------- dual updates (MDMM) -----------------
|
||||
for g, c in zip(best_g_vals, self.constraints):
|
||||
c.update_lambda(g, self.eta, self.gamma)
|
||||
|
||||
self.recorder.log(
|
||||
step=step,
|
||||
reward=best_reward,
|
||||
g_list=best_g_vals,
|
||||
lambda_list=[c.lambda_ for c in self.constraints],
|
||||
)
|
||||
|
||||
if verbose and isinstance(pbar, tqdm):
|
||||
lam_str = ", ".join(
|
||||
f"λ_{i}={c.lambda_:.2f}" for i, c in enumerate(self.constraints)
|
||||
)
|
||||
pbar.set_description(f"S={best_reward:+.3f} {lam_str}")
|
||||
|
||||
final = self.client.forward_and_sample(
|
||||
protein_tensor,
|
||||
sampling_configuration=SamplingConfig(
|
||||
sequence=SamplingTrackConfig(temperature=0.0),
|
||||
structure=SamplingTrackConfig(temperature=0.0),
|
||||
),
|
||||
)
|
||||
assert not isinstance(final, ESMProteinError)
|
||||
decoded = self.client.decode(final.protein_tensor)
|
||||
assert not isinstance(decoded, ESMProteinError)
|
||||
return decoded
|
||||
|
||||
def visualize_latest_trajectory(
|
||||
self, constraint_idx: int = 0, cmap: str = "viridis"
|
||||
) -> None:
|
||||
"""
|
||||
Visualise the trajectory of the latest optimisation run.
|
||||
If you optimise multiple constraints, pick which one to plot via `constraint_idx`.
|
||||
"""
|
||||
if not self.recorder:
|
||||
raise ValueError("No trajectory recorder available.")
|
||||
|
||||
steps, g_vals, rewards = self.recorder.as_arrays(constraint_idx)
|
||||
self.recorder.plot_line(constraint_idx=constraint_idx, cmap=cmap)
|
||||
|
||||
def _score_and_constraints(
|
||||
self, pt: ESMProteinTensor, temp: float
|
||||
) -> tuple[float, list[float], list[float]]:
|
||||
protein = self.predict_denoised(pt, temperature=temp)
|
||||
reward = self.scoring_function(protein)
|
||||
vals, ghs = [], []
|
||||
for c in self.constraints:
|
||||
val = c.scoring_function(protein)
|
||||
vals.append(val)
|
||||
ghs.append(c.g(val))
|
||||
return reward, ghs, vals
|
||||
|
||||
def _lagrangian(self, reward: float, g_vals: list[float]) -> float:
|
||||
"""
|
||||
MDMM L(x, λ) = -reward + Σ_i (λ_i - γ g_i) * g_i
|
||||
(reward is to be maximised ⇒ we minimise -reward)
|
||||
"""
|
||||
lag = -reward
|
||||
for g, c in zip(g_vals, self.constraints):
|
||||
lag += (c.lambda_ - self.gamma * g) * g
|
||||
return lag
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrajectoryRecorder:
|
||||
steps: List[int] = field(default_factory=list)
|
||||
rewards: List[float] = field(default_factory=list)
|
||||
g_vals: List[List[float]] = field(
|
||||
default_factory=list
|
||||
) # each step → list of constraints
|
||||
lambdas: List[List[float]] = field(default_factory=list) # each step → list of λ s
|
||||
|
||||
def log(
|
||||
self, step: int, reward: float, g_list: list[float], lambda_list: list[float]
|
||||
) -> None:
|
||||
"""Append one optimisation step to the trajectory."""
|
||||
self.steps.append(step)
|
||||
self.rewards.append(reward)
|
||||
self.g_vals.append(list(g_list))
|
||||
self.lambdas.append(list(lambda_list))
|
||||
|
||||
def as_arrays(self, constraint_idx: int = 0):
|
||||
"""
|
||||
Return numpy arrays suitable for plotting.
|
||||
If you optimise multiple constraints, pick which one to plot via `constraint_idx`.
|
||||
"""
|
||||
return (
|
||||
np.asarray(self.steps),
|
||||
np.asarray([g[constraint_idx] for g in self.g_vals]),
|
||||
np.asarray(self.rewards),
|
||||
)
|
||||
|
||||
def plot_line(self, constraint_idx: int = 0, cmap: str = "viridis"):
|
||||
"""
|
||||
Continuous line with markers and a colour-gradient that follows the optimisation step.
|
||||
"""
|
||||
steps, x_vals, y_vals = self.as_arrays(constraint_idx)
|
||||
|
||||
# build coloured line segments
|
||||
points = np.column_stack([x_vals, y_vals])
|
||||
segments = np.concatenate([points[:-1, None, :], points[1:, None, :]], axis=1)
|
||||
|
||||
norm = Normalize(vmin=steps.min(), vmax=steps.max())
|
||||
lc = LineCollection(segments, cmap=cmap, norm=norm, linewidth=2) # type: ignore
|
||||
lc.set_array(steps)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.add_collection(lc)
|
||||
ax.scatter(
|
||||
x_vals,
|
||||
y_vals,
|
||||
c=steps,
|
||||
cmap=cmap,
|
||||
norm=norm,
|
||||
marker="o",
|
||||
edgecolor="k",
|
||||
zorder=3,
|
||||
)
|
||||
|
||||
ax.axvline(0, linestyle="--", color="grey")
|
||||
ax.set_xlabel("constraint value g(x) (≤ 0 is feasible)")
|
||||
ax.set_ylabel("reward R(x)")
|
||||
ax.set_title("Trajectory in constraint–reward space")
|
||||
plt.colorbar(lc, label="optimisation step")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
@@ -44,7 +44,7 @@ class ESM3GuidedDecoding:
|
||||
self._use_batch_executor = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"client must be an instance of ESM3 or ESM3ForgeInferenceClient"
|
||||
f"client must be an instance of ESM3 or ESM3ForgeInferenceClient. Got {type(client)}"
|
||||
)
|
||||
|
||||
self.client = client
|
||||
211
esm/sdk/forge.py
211
esm/sdk/forge.py
@@ -1,15 +1,10 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import pickle
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
from typing import Any, Literal, Sequence, cast
|
||||
from typing import Any, Sequence
|
||||
|
||||
import torch
|
||||
from attr import asdict
|
||||
from tenacity import retry, retry_if_result, stop_after_attempt, wait_exponential
|
||||
|
||||
from esm.sdk.api import (
|
||||
MSA,
|
||||
@@ -29,17 +24,11 @@ from esm.sdk.api import (
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
||||
from esm.utils.misc import (
|
||||
deserialize_tensors,
|
||||
maybe_list,
|
||||
maybe_tensor,
|
||||
)
|
||||
from esm.utils.sampling import validate_sampling_config
|
||||
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
skip_retries_var = ContextVar("skip_retries", default=False)
|
||||
|
||||
|
||||
def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
|
||||
if l is None or len(l) <= 0:
|
||||
@@ -47,21 +36,17 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
|
||||
return [FunctionAnnotation(*t) for t in l]
|
||||
|
||||
|
||||
def retry_if_specific_error(exception):
|
||||
"""
|
||||
We only retry on specific errors.
|
||||
Currently we retry for 502 (bad gateway) and 429 (rate limit)
|
||||
"""
|
||||
return isinstance(exception, ESMProteinError) and exception.error_code in {
|
||||
429,
|
||||
502,
|
||||
504,
|
||||
}
|
||||
def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
|
||||
ret = data.get("logits", {}).get(track, None)
|
||||
# TODO(s22chan): just return this when removing return_bytes
|
||||
return ret if ret is None or not return_bytes else maybe_tensor(ret)
|
||||
|
||||
|
||||
def log_retry_attempt(retry_state):
|
||||
print(
|
||||
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
|
||||
def _maybe_b64_decode(obj, return_bytes: bool):
|
||||
return (
|
||||
deserialize_tensors(base64.b64decode(obj))
|
||||
if return_bytes and isinstance(obj, str)
|
||||
else obj
|
||||
)
|
||||
|
||||
|
||||
@@ -77,45 +62,6 @@ def _validate_protein_tensor_input(input):
|
||||
)
|
||||
|
||||
|
||||
def retry_decorator(func):
|
||||
"""
|
||||
A static method that returns a retry decorator. This decorator uses the
|
||||
instance's retry settings.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(instance, *args, **kwargs):
|
||||
if skip_retries_var.get():
|
||||
return await func(instance, *args, **kwargs)
|
||||
retry_decorator = retry(
|
||||
retry=retry_if_result(retry_if_specific_error),
|
||||
wait=wait_exponential(
|
||||
multiplier=1, min=instance.min_retry_wait, max=instance.max_retry_wait
|
||||
),
|
||||
stop=stop_after_attempt(instance.max_retry_attempts),
|
||||
before_sleep=log_retry_attempt,
|
||||
)
|
||||
# Apply the retry decorator to the function
|
||||
return await retry_decorator(func)(instance, *args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(instance, *args, **kwargs):
|
||||
if skip_retries_var.get():
|
||||
return func(instance, *args, **kwargs)
|
||||
retry_decorator = retry(
|
||||
retry=retry_if_result(retry_if_specific_error),
|
||||
wait=wait_exponential(
|
||||
multiplier=1, min=instance.min_retry_wait, max=instance.max_retry_wait
|
||||
),
|
||||
stop=stop_after_attempt(instance.max_retry_attempts),
|
||||
before_sleep=log_retry_attempt,
|
||||
)
|
||||
# Apply the retry decorator to the function
|
||||
return retry_decorator(func)(instance, *args, **kwargs)
|
||||
|
||||
return async_wrapper if inspect.iscoroutinefunction(func) else wrapper
|
||||
|
||||
|
||||
class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -147,13 +93,9 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_fold_request(
|
||||
sequence: str,
|
||||
model_name: str | None,
|
||||
):
|
||||
def _process_fold_request(sequence: str, model_name: str | None):
|
||||
request: dict[str, Any] = {"sequence": sequence}
|
||||
|
||||
|
||||
request["model"] = model_name
|
||||
|
||||
return request
|
||||
@@ -164,12 +106,15 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
sequence=sequence,
|
||||
coordinates=maybe_tensor(data["coordinates"], convert_none_to_nan=True),
|
||||
ptm=maybe_tensor(data.get("ptm", None)),
|
||||
plddt=maybe_tensor(data.get("plddt", None)),
|
||||
plddt=maybe_tensor(data.get("plddt", None), convert_none_to_nan=True),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def process_inverse_fold_request(
|
||||
coordinates: torch.Tensor, config: InverseFoldingConfig, model_name: str | None
|
||||
coordinates: torch.Tensor,
|
||||
sequence: str | None,
|
||||
config: InverseFoldingConfig,
|
||||
model_name: str | None,
|
||||
):
|
||||
inverse_folding_config = {
|
||||
"invalid_ids": config.invalid_ids,
|
||||
@@ -178,6 +123,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
request = {
|
||||
"coordinates": maybe_list(coordinates, convert_nan_to_none=True),
|
||||
"inverse_folding_config": inverse_folding_config,
|
||||
"sequence": sequence,
|
||||
}
|
||||
if model_name is not None:
|
||||
request["model"] = model_name
|
||||
@@ -222,13 +168,13 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
del potential_sequence_of_concern
|
||||
|
||||
request = self._process_fold_request(
|
||||
sequence,
|
||||
model_name if model_name is not None else self.model,
|
||||
sequence, model_name if model_name is not None else self.model
|
||||
)
|
||||
|
||||
# Intentionally not catching errors, so our higher level logic such as automatic
|
||||
# batch runner gets a chance to handle different errors properly.
|
||||
data = await self._async_post("fold", request)
|
||||
try:
|
||||
data = await self._async_post("fold", request)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_response(data, sequence)
|
||||
|
||||
@@ -249,16 +195,17 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
potential_sequence_of_concern: this parameter is largely deprecated
|
||||
and ignored by the folding endpoint.
|
||||
"""
|
||||
|
||||
del potential_sequence_of_concern
|
||||
|
||||
request = self._process_fold_request(
|
||||
sequence,
|
||||
model_name if model_name is not None else self.model,
|
||||
sequence, model_name if model_name is not None else self.model
|
||||
)
|
||||
|
||||
# Intentionally not catching errors, so our higher level logic such as automatic
|
||||
# batch runner gets a chance to handle different errors properly.
|
||||
data = self._post("fold", request)
|
||||
try:
|
||||
data = self._post("fold", request)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return self._process_fold_response(data, sequence)
|
||||
|
||||
@@ -268,6 +215,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
coordinates: torch.Tensor,
|
||||
config: InverseFoldingConfig,
|
||||
potential_sequence_of_concern: bool,
|
||||
sequence: str | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
"""Generate protein sequence from its structure.
|
||||
@@ -282,14 +230,18 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
model_name: Override the client level model name if needed.
|
||||
"""
|
||||
request = self.process_inverse_fold_request(
|
||||
coordinates, config, model_name if model_name is not None else self.model
|
||||
coordinates,
|
||||
sequence,
|
||||
config,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
# Intentionally not catching errors, so our higher level logic such as automatic
|
||||
# batch runner gets a chance to handle different errors properly.
|
||||
data = await self._async_post(
|
||||
"inverse_fold", request, potential_sequence_of_concern
|
||||
)
|
||||
try:
|
||||
data = await self._async_post(
|
||||
"inverse_fold", request, potential_sequence_of_concern
|
||||
)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProtein(sequence=data["sequence"])
|
||||
|
||||
@@ -299,6 +251,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
coordinates: torch.Tensor,
|
||||
config: InverseFoldingConfig,
|
||||
potential_sequence_of_concern: bool,
|
||||
sequence: str | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
"""Generate protein sequence from its structure.
|
||||
@@ -313,12 +266,16 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
model_name: Override the client level model name if needed.
|
||||
"""
|
||||
request = self.process_inverse_fold_request(
|
||||
coordinates, config, model_name if model_name is not None else self.model
|
||||
coordinates,
|
||||
sequence,
|
||||
config,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
|
||||
# Intentionally not catching errors, so our higher level logic such as automatic
|
||||
# batch runner gets a chance to handle different errors properly.
|
||||
data = self._post("inverse_fold", request, potential_sequence_of_concern)
|
||||
try:
|
||||
data = self._post("inverse_fold", request, potential_sequence_of_concern)
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProtein(sequence=data["sequence"])
|
||||
|
||||
@@ -442,6 +399,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
def _process_forward_and_sample_request(
|
||||
input: ESMProteinTensor, sampling_configuration: SamplingConfig, model_name: str
|
||||
) -> dict[str, Any]:
|
||||
from esm.utils.sampling import validate_sampling_config
|
||||
|
||||
_validate_protein_tensor_input(input)
|
||||
validate_sampling_config(sampling_configuration, on_invalid="raise")
|
||||
|
||||
@@ -623,7 +582,9 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
"function": config.function,
|
||||
"residue_annotations": config.residue_annotations,
|
||||
"return_embeddings": config.return_embeddings,
|
||||
"return_mean_embedding": config.return_mean_embedding,
|
||||
"return_hidden_states": config.return_hidden_states,
|
||||
"return_mean_hidden_states": config.return_mean_hidden_states,
|
||||
"ith_hidden_layer": config.ith_hidden_layer,
|
||||
}
|
||||
|
||||
@@ -634,34 +595,28 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
def _process_logits_response(
|
||||
data: dict[str, Any], return_bytes: bool
|
||||
) -> LogitsOutput:
|
||||
def _maybe_logits(track: str):
|
||||
ret = data.get("logits", {}).get(track, None)
|
||||
# TODO(s22chan): just return this when removing return_bytes
|
||||
return ret if ret is None or not return_bytes else maybe_tensor(ret)
|
||||
|
||||
def _maybe_b64_decode(obj):
|
||||
return (
|
||||
deserialize_tensors(base64.b64decode(obj))
|
||||
if return_bytes and isinstance(obj, str)
|
||||
else obj
|
||||
)
|
||||
|
||||
logits = _maybe_b64_decode(data["logits"])
|
||||
logits = _maybe_b64_decode(data["logits"], return_bytes)
|
||||
data["logits"] = dict(logits) if logits is not None else logits
|
||||
data["embeddings"] = _maybe_b64_decode(data["embeddings"])
|
||||
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"])
|
||||
data["embeddings"] = _maybe_b64_decode(data["embeddings"], return_bytes)
|
||||
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
|
||||
|
||||
return LogitsOutput(
|
||||
logits=ForwardTrackData(
|
||||
sequence=_maybe_logits("sequence"),
|
||||
structure=_maybe_logits("structure"),
|
||||
secondary_structure=_maybe_logits("secondary_structure"),
|
||||
sasa=_maybe_logits("sasa"),
|
||||
function=_maybe_logits("function"),
|
||||
sequence=_maybe_logits(data, "sequence", return_bytes),
|
||||
structure=_maybe_logits(data, "structure", return_bytes),
|
||||
secondary_structure=_maybe_logits(
|
||||
data, "secondary_structure", return_bytes
|
||||
),
|
||||
sasa=_maybe_logits(data, "sasa", return_bytes),
|
||||
function=_maybe_logits(data, "function", return_bytes),
|
||||
),
|
||||
embeddings=maybe_tensor(data["embeddings"]),
|
||||
residue_annotation_logits=_maybe_logits("residue_annotation"),
|
||||
mean_embedding=data["mean_embedding"],
|
||||
residue_annotation_logits=_maybe_logits(
|
||||
data, "residue_annotation", return_bytes
|
||||
),
|
||||
hidden_states=maybe_tensor(data["hidden_states"]),
|
||||
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
@@ -741,8 +696,8 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
try:
|
||||
res = self.async_generate(protein, config)
|
||||
return await res
|
||||
except Exception as e:
|
||||
return ESMProteinError(500, str(e))
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
tasks = [
|
||||
safe_generate(protein, config) for protein, config in zip(inputs, configs)
|
||||
@@ -1009,6 +964,7 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
logits_config = {
|
||||
"sequence": config.sequence,
|
||||
"return_embeddings": config.return_embeddings,
|
||||
"return_mean_embedding": config.return_mean_embedding,
|
||||
"return_hidden_states": config.return_hidden_states,
|
||||
"ith_hidden_layer": config.ith_hidden_layer,
|
||||
}
|
||||
@@ -1019,26 +975,17 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
def _process_logits_response(
|
||||
data: dict[str, Any], return_bytes: bool
|
||||
) -> LogitsOutput:
|
||||
def _maybe_logits(track: str):
|
||||
ret = data.get("logits", {}).get(track, None)
|
||||
# TODO(s22chan): just return this when removing return_bytes
|
||||
return ret if ret is None or not return_bytes else maybe_tensor(ret)
|
||||
|
||||
def _maybe_b64_decode(obj):
|
||||
return (
|
||||
deserialize_tensors(base64.b64decode(obj))
|
||||
if return_bytes and isinstance(obj, str)
|
||||
else obj
|
||||
)
|
||||
|
||||
logits = _maybe_b64_decode(data["logits"])
|
||||
logits = _maybe_b64_decode(data["logits"], return_bytes)
|
||||
data["logits"] = dict(logits) if logits is not None else logits
|
||||
data["embeddings"] = _maybe_b64_decode(data["embeddings"])
|
||||
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"])
|
||||
data["embeddings"] = _maybe_b64_decode(data["embeddings"], return_bytes)
|
||||
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
|
||||
|
||||
output = LogitsOutput(
|
||||
logits=ForwardTrackData(sequence=_maybe_logits("sequence")),
|
||||
logits=ForwardTrackData(
|
||||
sequence=_maybe_logits(data, "sequence", return_bytes)
|
||||
),
|
||||
embeddings=maybe_tensor(data["embeddings"]),
|
||||
mean_embedding=data["mean_embedding"],
|
||||
hidden_states=maybe_tensor(data["hidden_states"]),
|
||||
)
|
||||
return output
|
||||
|
||||
76
esm/sdk/retry.py
Normal file
76
esm/sdk/retry.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import inspect
|
||||
from contextvars import ContextVar
|
||||
from functools import wraps
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
retry_if_result,
|
||||
stop_after_attempt,
|
||||
wait_incrementing,
|
||||
)
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
|
||||
skip_retries_var = ContextVar("skip_retries", default=False)
|
||||
|
||||
|
||||
def retry_if_specific_error(exception):
|
||||
"""
|
||||
We only retry on specific errors.
|
||||
"""
|
||||
return isinstance(exception, ESMProteinError) and exception.error_code in {
|
||||
429,
|
||||
500,
|
||||
502,
|
||||
504,
|
||||
500,
|
||||
}
|
||||
|
||||
|
||||
def log_retry_attempt(retry_state):
|
||||
print(
|
||||
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
|
||||
)
|
||||
|
||||
|
||||
def retry_decorator(func):
|
||||
"""
|
||||
A static method that returns a retry decorator. This decorator uses the
|
||||
instance's retry settings.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(instance, *args, **kwargs):
|
||||
if skip_retries_var.get():
|
||||
return await func(instance, *args, **kwargs)
|
||||
retry_decorator = retry(
|
||||
retry=retry_if_result(retry_if_specific_error)
|
||||
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
|
||||
wait=wait_incrementing(
|
||||
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
|
||||
),
|
||||
stop=stop_after_attempt(instance.max_retry_attempts),
|
||||
before_sleep=log_retry_attempt,
|
||||
)
|
||||
# Apply the retry decorator to the function
|
||||
return await retry_decorator(func)(instance, *args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(instance, *args, **kwargs):
|
||||
if skip_retries_var.get():
|
||||
return func(instance, *args, **kwargs)
|
||||
retry_decorator = retry(
|
||||
retry=retry_if_result(retry_if_specific_error)
|
||||
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
|
||||
wait=wait_incrementing(
|
||||
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
|
||||
),
|
||||
stop=stop_after_attempt(instance.max_retry_attempts),
|
||||
before_sleep=log_retry_attempt,
|
||||
)
|
||||
# Apply the retry decorator to the function
|
||||
return retry_decorator(func)(instance, *args, **kwargs)
|
||||
|
||||
return async_wrapper if inspect.iscoroutinefunction(func) else wrapper
|
||||
@@ -22,7 +22,7 @@ class SequenceStructureSageMakerClient(SequenceStructureForgeInferenceClient):
|
||||
|
||||
self._boto3_client = boto3.client(service_name="sagemaker-runtime")
|
||||
|
||||
def _post(self, endpoint, request, potential_sequence_of_concern):
|
||||
def _post(self, endpoint, request, potential_sequence_of_concern: bool = False):
|
||||
request["potential_sequence_of_concern"] = potential_sequence_of_concern
|
||||
request["model"] = request.get("model", None)
|
||||
invocations_request = {
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from esm.utils.constants.models import (
|
||||
ESM3_OPEN_SMALL,
|
||||
normalize_model_name,
|
||||
)
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL, normalize_model_name
|
||||
|
||||
from .function_tokenizer import InterProQuantizedTokenizer
|
||||
from .residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
|
||||
@@ -10,24 +10,12 @@ from esm.models.function_decoder import FunctionTokenDecoder
|
||||
from esm.models.vqvae import StructureTokenDecoder
|
||||
from esm.sdk.api import ESMProtein, ESMProteinTensor
|
||||
from esm.tokenization import TokenizerCollectionProtocol
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.tokenization.sasa_tokenizer import (
|
||||
SASADiscretizingTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.ss_tokenizer import (
|
||||
SecondaryStructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.structure_tokenizer import (
|
||||
StructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||
from esm.tokenization.tokenizer_base import EsmTokenizerBase
|
||||
from esm.utils.constants import api as api_constants
|
||||
from esm.utils.constants import esm3 as C
|
||||
@@ -251,6 +239,7 @@ def assemble_message(headers: Mapping[str, str], response: Response) -> dict[str
|
||||
content_type = headers.get("Content-Type", "application/json")
|
||||
if content_type == api_constants.MIMETYPE_ES_PICKLE:
|
||||
return pickle.loads(response.content)
|
||||
elif content_type == "application/json":
|
||||
elif "application/json" in content_type:
|
||||
# Can handle something like "application/json; charset=utf-8"
|
||||
return response.json()
|
||||
raise ValueError(f"Unknown Content-Type: {content_type}")
|
||||
|
||||
@@ -7,26 +7,13 @@ from esm.models.vqvae import StructureTokenEncoder
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer as EsmFunctionTokenizer,
|
||||
)
|
||||
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.tokenization.sasa_tokenizer import (
|
||||
SASADiscretizingTokenizer,
|
||||
)
|
||||
from esm.tokenization.sequence_tokenizer import (
|
||||
EsmSequenceTokenizer,
|
||||
)
|
||||
from esm.tokenization.ss_tokenizer import (
|
||||
SecondaryStructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.structure_tokenizer import (
|
||||
StructureTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.tokenization.sasa_tokenizer import SASADiscretizingTokenizer
|
||||
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
|
||||
from esm.tokenization.ss_tokenizer import SecondaryStructureTokenizer
|
||||
from esm.tokenization.structure_tokenizer import StructureTokenizer
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.function.encode_decode import (
|
||||
encode_function_annotations,
|
||||
)
|
||||
from esm.utils.function.encode_decode import encode_function_annotations
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
@@ -165,8 +152,6 @@ def tokenize_function_annotations(
|
||||
return function_tokens, residue_annotation_tokens
|
||||
|
||||
|
||||
|
||||
|
||||
# Tokenized Defaults
|
||||
def get_default_sequence_tokens(
|
||||
sequence_length: int, sequence_tokenizer: EsmSequenceTokenizer
|
||||
@@ -242,5 +227,3 @@ def get_default_residue_annotation_tokens(
|
||||
residue_annotation_tokens[0] = residue_annotation_tokenizer.bos_token_id
|
||||
residue_annotation_tokens[-1] = residue_annotation_tokenizer.eos_token_id
|
||||
return residue_annotation_tokens
|
||||
|
||||
|
||||
|
||||
@@ -7,10 +7,7 @@ from typing import Any, Callable, List
|
||||
from tqdm import tqdm
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
from esm.sdk.forge import (
|
||||
retry_if_specific_error,
|
||||
skip_retries_var,
|
||||
)
|
||||
from esm.sdk.retry import retry_if_specific_error, skip_retries_var
|
||||
|
||||
TQDM_BAR_FORMAT = (
|
||||
"{desc:<12}{percentage:3.0f}%|{bar:24}| {n_fmt}/{total_fmt} "
|
||||
@@ -25,7 +22,7 @@ class AIMDRateLimiter:
|
||||
self,
|
||||
initial_concurrency: int = 32,
|
||||
min_concurrency: int = 1,
|
||||
max_concurrency: int = 512,
|
||||
max_concurrency: int = 64,
|
||||
step_up: int = 1,
|
||||
):
|
||||
self.concurrency = initial_concurrency
|
||||
@@ -56,8 +53,10 @@ class ForgeBatchExecutor:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, max_attempts: int = 10, max_workers: int = 512, show_progress: bool = True
|
||||
self, max_attempts: int = 10, max_workers: int = 64, show_progress: bool = True
|
||||
):
|
||||
if max_workers > 64:
|
||||
raise ValueError("max_workers must be less than 64")
|
||||
self.rate_limiter = AIMDRateLimiter(max_concurrency=max_workers)
|
||||
self.max_attempts = max_attempts
|
||||
self.show_progress = show_progress
|
||||
|
||||
@@ -3,16 +3,9 @@ from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from esm.models.function_decoder import (
|
||||
FunctionTokenDecoder,
|
||||
merge_annotations,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.residue_tokenizer import (
|
||||
ResidueAnnotationsTokenizer,
|
||||
)
|
||||
from esm.models.function_decoder import FunctionTokenDecoder, merge_annotations
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.tokenization.residue_tokenizer import ResidueAnnotationsTokenizer
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class LSHTokenized:
|
||||
hyperplanes: dict[str, np.ndarray] = { # type: ignore
|
||||
str(i): table.hyperplanes for i, table in enumerate(self.tables)
|
||||
}
|
||||
np.savez(filepath, **hyperplanes)
|
||||
np.savez(filepath, **hyperplanes) # type: ignore
|
||||
|
||||
def __call__(self, array):
|
||||
tokens = np.stack([table(array) for table in self.tables], 1)
|
||||
|
||||
@@ -19,13 +19,8 @@ from esm.sdk.api import (
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
EsmTokenizerBase,
|
||||
TokenizerCollectionProtocol,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization import EsmTokenizerBase, TokenizerCollectionProtocol
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import stack_variable_length_tensors
|
||||
from esm.utils.noise_schedules import NOISE_SCHEDULE_REGISTRY
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from evolutionaryscale.models.esm3v2 import Esm3v2
|
||||
from esm.sdk.api import (
|
||||
ESMProtein,
|
||||
ESMProteinTensor,
|
||||
GenerationConfig,
|
||||
)
|
||||
from evolutionaryscale.utils.env import ModelName
|
||||
from evolutionaryscale.utils.remote_inference.api_v1 import (
|
||||
ESM3RemoteModelInferenceClient,
|
||||
)
|
||||
from projects.forge.inference.utils.model import _load_esm_model
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def esm3_remote_inference_client():
|
||||
model = _load_esm_model(
|
||||
ModelName.ESM3_TINY_DEV, distributed_model=False, load_function_decoder=False
|
||||
)
|
||||
assert isinstance(model, Esm3v2)
|
||||
client = ESM3RemoteModelInferenceClient(
|
||||
model,
|
||||
tokenizers=model.tokenizers,
|
||||
device=torch.device("cuda"),
|
||||
enable_batched_runner=False,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_chain_break_tokens(esm3_remote_inference_client):
|
||||
tokenizer = esm3_remote_inference_client.tokenizers.sequence
|
||||
# 3 separate chains with 2 chainbreak tokens.
|
||||
sequence_with_chain_breaks = torch.tensor(
|
||||
[
|
||||
tokenizer.bos_token_id,
|
||||
20,
|
||||
20,
|
||||
20,
|
||||
20,
|
||||
tokenizer.chain_break_token_id,
|
||||
21,
|
||||
21,
|
||||
21,
|
||||
tokenizer.chain_break_token_id,
|
||||
22,
|
||||
22,
|
||||
22,
|
||||
tokenizer.eos_token_id,
|
||||
]
|
||||
)
|
||||
protein = esm3_remote_inference_client.generate(
|
||||
ESMProteinTensor(sequence=sequence_with_chain_breaks),
|
||||
# There are 10 tokens that actually need to be sampled.
|
||||
GenerationConfig(track="structure", num_steps=10),
|
||||
)
|
||||
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.structure is not None
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_num_decoding_steps_more_than_mask_tokens(esm3_remote_inference_client):
|
||||
protein = esm3_remote_inference_client.generate(
|
||||
esm3_remote_inference_client.encode(
|
||||
ESMProtein(sequence="CDEFG")
|
||||
), # sequence of 5.
|
||||
GenerationConfig(track="structure", num_steps=10), # use 10 decoding steps.
|
||||
)
|
||||
# Client should handle over-specification of decoding steps.
|
||||
# TODO: This should be a warning.
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.structure is not None
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_num_decoding_steps_more_than_mask_tokens_batched(esm3_remote_inference_client):
|
||||
protein_list = esm3_remote_inference_client.batch_generate(
|
||||
inputs=[
|
||||
esm3_remote_inference_client.encode(ESMProtein(sequence="CDEFG")),
|
||||
esm3_remote_inference_client.encode(ESMProtein(sequence="ABCDEFG")),
|
||||
esm3_remote_inference_client.encode(ESMProtein(sequence="AB__EFG")),
|
||||
],
|
||||
configs=[
|
||||
GenerationConfig(track="structure", num_steps=10),
|
||||
GenerationConfig(track="structure", num_steps=3),
|
||||
GenerationConfig(track="sequence", num_steps=20),
|
||||
],
|
||||
)
|
||||
# Client should handle over-specification of decoding steps.
|
||||
# TODO: This should be a warning.
|
||||
assert isinstance(protein_list[0], ESMProteinTensor)
|
||||
assert protein_list[0].structure is not None
|
||||
assert isinstance(protein_list[1], ESMProteinTensor)
|
||||
assert protein_list[1].structure is not None
|
||||
assert isinstance(protein_list[2], ESMProteinTensor)
|
||||
assert protein_list[2].sequence is not None
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_encode_chainbreak_token(esm3_remote_inference_client):
|
||||
protein = esm3_remote_inference_client.encode(ESMProtein(sequence="MSTNP|KPQKK"))
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.sequence is not None
|
||||
assert (
|
||||
protein.sequence[6]
|
||||
== esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_generation_with_chainbreak_token(esm3_remote_inference_client):
|
||||
chainbreak_sequence = torch.tensor(
|
||||
[
|
||||
esm3_remote_inference_client.tokenizers.sequence.bos_token_id,
|
||||
20,
|
||||
8,
|
||||
11,
|
||||
17,
|
||||
14,
|
||||
esm3_remote_inference_client.tokenizers.sequence.chain_break_token_id,
|
||||
15,
|
||||
14,
|
||||
16,
|
||||
15,
|
||||
15,
|
||||
esm3_remote_inference_client.tokenizers.sequence.eos_token_id,
|
||||
]
|
||||
)
|
||||
|
||||
protein = esm3_remote_inference_client.generate(
|
||||
ESMProteinTensor(sequence=chainbreak_sequence),
|
||||
GenerationConfig(track="structure", num_steps=1),
|
||||
)
|
||||
# Can't specify more decoding steps than masks available.
|
||||
assert isinstance(protein, ESMProteinTensor)
|
||||
assert protein.structure is not None
|
||||
assert (
|
||||
protein.structure[6]
|
||||
== esm3_remote_inference_client.tokenizers.structure.chain_break_token_id
|
||||
)
|
||||
@@ -33,15 +33,15 @@ def slice_python_object_as_numpy(
|
||||
>>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
|
||||
[1, 2, 3]
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
idx = [idx]
|
||||
if np.isscalar(idx):
|
||||
idx = [int(idx)] # type: ignore
|
||||
|
||||
if isinstance(idx, np.ndarray) and idx.dtype == bool:
|
||||
sliced_obj = [obj[i] for i in np.where(idx)[0]]
|
||||
elif isinstance(idx, slice):
|
||||
sliced_obj = obj[idx]
|
||||
else:
|
||||
sliced_obj = [obj[i] for i in idx]
|
||||
sliced_obj = [obj[i] for i in idx] # type: ignore
|
||||
|
||||
match obj, sliced_obj:
|
||||
case str(), list():
|
||||
@@ -156,6 +156,37 @@ def stack_variable_length_tensors(
|
||||
return array
|
||||
|
||||
|
||||
def binpack(
|
||||
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tensor (Tensor): [B, L, ...]
|
||||
|
||||
Returns:
|
||||
Tensor: [B_binpacked, L_binpacked, ...]
|
||||
"""
|
||||
if sequence_id is None:
|
||||
return tensor
|
||||
|
||||
num_sequences = sequence_id.max(dim=-1).values + 1
|
||||
|
||||
dims = sequence_id.shape + tensor.shape[2:]
|
||||
output_tensor = torch.full(
|
||||
dims, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
|
||||
idx = 0
|
||||
for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
|
||||
zip(sequence_id, num_sequences)
|
||||
):
|
||||
for seqid in range(batch_num_sequences):
|
||||
mask = batch_seqid == seqid
|
||||
output_tensor[batch_idx, mask] = tensor[idx, : mask.sum()]
|
||||
idx += 1
|
||||
return output_tensor
|
||||
|
||||
|
||||
def unbinpack(
|
||||
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
|
||||
):
|
||||
@@ -280,9 +311,18 @@ def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
|
||||
return None
|
||||
if not convert_nan_to_none:
|
||||
return x.tolist()
|
||||
nan_mask = torch.isnan(x)
|
||||
np_arr = x.cpu().numpy().astype(object)
|
||||
np_arr[nan_mask.cpu().numpy()] = None
|
||||
|
||||
# Handle both torch.tensor and np.ndarray input.
|
||||
if isinstance(x, torch.Tensor):
|
||||
nan_mask = torch.isnan(x).cpu().numpy()
|
||||
np_arr = x.cpu().numpy().astype(object)
|
||||
elif isinstance(x, np.ndarray):
|
||||
nan_mask = np.isnan(x)
|
||||
np_arr = x.astype(object)
|
||||
else:
|
||||
raise TypeError("maybe_list can only work with torch.tensor or np.ndarray.")
|
||||
|
||||
np_arr[nan_mask] = None
|
||||
return np_arr.tolist()
|
||||
|
||||
|
||||
@@ -313,7 +353,6 @@ def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarr
|
||||
return chain_boundaries
|
||||
|
||||
|
||||
# TODO(return_bytes): remove when retiring return_bytes on SageMaker
|
||||
def deserialize_tensors(b: bytes) -> Any:
|
||||
buf = BytesIO(zstd.ZSTD_uncompress(b))
|
||||
d = torch.load(buf, map_location="cpu", weights_only=False)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,18 +5,9 @@ import attr
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from esm.sdk.api import (
|
||||
ESMProteinTensor,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.tokenization import (
|
||||
TokenizerCollectionProtocol,
|
||||
get_invalid_tokenizer_ids,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.sdk.api import ESMProteinTensor, SamplingConfig, SamplingTrackConfig
|
||||
from esm.tokenization import TokenizerCollectionProtocol, get_invalid_tokenizer_ids
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
from esm.utils.constants.esm3 import (
|
||||
MAX_RESIDUE_ANNOTATIONS,
|
||||
SASA_DISCRETIZATION_BOUNDARIES,
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as T
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from typing_extensions import Self
|
||||
|
||||
from esm.utils.misc import fp32_autocast_context
|
||||
|
||||
|
||||
@T.runtime_checkable
|
||||
class Rotation(T.Protocol):
|
||||
class Rotation(ABC):
|
||||
@classmethod
|
||||
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
|
||||
|
||||
@@ -34,6 +35,8 @@ class Rotation(T.Protocol):
|
||||
|
||||
def as_matrix(self) -> RotationMatrix: ...
|
||||
|
||||
def as_quat(self, normalize: bool = False) -> RotationQuat: ...
|
||||
|
||||
def compose(self, other: Self) -> Self:
|
||||
# To be safe, we force users to explicitly convert between rotation types.
|
||||
...
|
||||
@@ -87,7 +90,8 @@ class RotationMatrix(Rotation):
|
||||
assert rots.shape[-1] == 3
|
||||
assert rots.shape[-2] == 3
|
||||
# Force full precision
|
||||
self._rots = rots.to(torch.float32)
|
||||
rots = rots.to(torch.float32)
|
||||
self._rots = rots
|
||||
|
||||
@classmethod
|
||||
def identity(cls, shape, **tensor_kwargs):
|
||||
@@ -98,9 +102,7 @@ class RotationMatrix(Rotation):
|
||||
|
||||
@classmethod
|
||||
def random(cls, shape, **tensor_kwargs):
|
||||
v1 = torch.randn((*shape, 3), **tensor_kwargs)
|
||||
v2 = torch.randn((*shape, 3), **tensor_kwargs)
|
||||
return cls(_graham_schmidt(v1, v2))
|
||||
return RotationQuat.random(shape, **tensor_kwargs).as_matrix()
|
||||
|
||||
def __getitem__(self, idx: T.Any) -> RotationMatrix:
|
||||
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
|
||||
@@ -113,6 +115,49 @@ class RotationMatrix(Rotation):
|
||||
def as_matrix(self) -> RotationMatrix:
|
||||
return self
|
||||
|
||||
def as_quat(self, normalize: bool = False) -> RotationQuat:
|
||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||
self._rots.flatten(-2), dim=-1
|
||||
)
|
||||
q_abs = _sqrt_subgradient(
|
||||
torch.stack(
|
||||
[
|
||||
1.0 + m00 + m11 + m22,
|
||||
1.0 + m00 - m11 - m22,
|
||||
1.0 - m00 + m11 - m22,
|
||||
1.0 - m00 - m11 + m22,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
# we produce the desired quaternion multiplied by each of r, i, j, k
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
x
|
||||
for lst in [
|
||||
[q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
|
||||
[m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
|
||||
[m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
|
||||
[m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
|
||||
]
|
||||
for x in lst
|
||||
],
|
||||
dim=-1,
|
||||
).unflatten(-1, (4, 4))
|
||||
|
||||
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
||||
# the candidate won't be picked.
|
||||
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||
|
||||
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
||||
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||
# We manually implement one_hot so torch.compile works
|
||||
one_hot = torch.zeros_like(q_abs, dtype=torch.bool)
|
||||
one_hot.scatter_(-1, q_abs.argmax(dim=-1, keepdim=True), True)
|
||||
quat = quat_candidates[one_hot, :].reshape(q_abs.shape)
|
||||
return RotationQuat(quat)
|
||||
|
||||
def compose(self, other: RotationMatrix) -> RotationMatrix:
|
||||
with fp32_autocast_context(self._rots.device.type):
|
||||
return RotationMatrix(self._rots @ other._rots)
|
||||
@@ -147,6 +192,81 @@ class RotationMatrix(Rotation):
|
||||
return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
|
||||
|
||||
|
||||
class RotationQuat(Rotation):
|
||||
def __init__(self, quats: torch.Tensor, normalized=False):
|
||||
assert quats.shape[-1] == 4
|
||||
self._normalized = normalized
|
||||
# Force float32 as well
|
||||
if normalized:
|
||||
self._quats = F.normalize(quats.to(torch.float32), dim=-1)
|
||||
self._quats = self._quats.where(self._quats[..., :1] >= 0, -self._quats)
|
||||
else:
|
||||
self._quats = quats.to(torch.float32)
|
||||
|
||||
@classmethod
|
||||
def identity(cls, shape, **tensor_kwargs):
|
||||
q = torch.ones((*shape, 4), **tensor_kwargs)
|
||||
mult = torch.tensor([1, 0, 0, 0], device=q.device)
|
||||
return RotationQuat(q * mult)
|
||||
|
||||
@classmethod
|
||||
def random(cls, shape, **tensor_kwargs):
|
||||
quat = torch.randn((*shape, 4), **tensor_kwargs)
|
||||
return RotationQuat(quat, normalized=True)
|
||||
|
||||
def __getitem__(self, idx: T.Any) -> RotationQuat:
|
||||
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
|
||||
return RotationQuat(self._quats[indices + (slice(None),)])
|
||||
|
||||
@property
|
||||
def shape(self) -> torch.Size:
|
||||
return self._quats.shape[:-1]
|
||||
|
||||
def compose(self, other: RotationQuat) -> RotationQuat:
|
||||
with fp32_autocast_context(self._quats.device.type):
|
||||
return RotationQuat(_quat_mult(self._quats, other._quats))
|
||||
|
||||
def convert_compose(self, other: Rotation):
|
||||
return self.compose(other.as_quat())
|
||||
|
||||
def as_matrix(self) -> RotationMatrix:
|
||||
q = self.normalized().tensor
|
||||
r, i, j, k = torch.unbind(q, -1)
|
||||
two_s = 2.0 / torch.linalg.norm(q, dim=-1)
|
||||
|
||||
o = torch.stack(
|
||||
(
|
||||
1 - two_s * (j * j + k * k),
|
||||
two_s * (i * j - k * r),
|
||||
two_s * (i * k + j * r),
|
||||
two_s * (i * j + k * r),
|
||||
1 - two_s * (i * i + k * k),
|
||||
two_s * (j * k - i * r),
|
||||
two_s * (i * k - j * r),
|
||||
two_s * (j * k + i * r),
|
||||
1 - two_s * (i * i + j * j),
|
||||
),
|
||||
-1,
|
||||
)
|
||||
return RotationMatrix(o.reshape(q.shape[:-1] + (3, 3)))
|
||||
|
||||
def as_quat(self, normalize: bool = False) -> RotationQuat:
|
||||
return self
|
||||
|
||||
def apply(self, p: torch.Tensor) -> torch.Tensor:
|
||||
return _quat_rotation(self.normalized()._quats, p)
|
||||
|
||||
def invert(self) -> RotationQuat:
|
||||
return RotationQuat(_quat_invert(self._quats))
|
||||
|
||||
@property
|
||||
def tensor(self) -> torch.Tensor:
|
||||
return self._quats
|
||||
|
||||
def normalized(self) -> RotationQuat:
|
||||
return self if self._normalized else RotationQuat(self._quats, normalized=True)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Affine3D:
|
||||
trans: torch.Tensor
|
||||
@@ -222,6 +342,9 @@ class Affine3D:
|
||||
def as_matrix(self):
|
||||
return Affine3D(trans=self.trans, rot=self.rot.as_matrix())
|
||||
|
||||
def as_quat(self, normalize: bool = False):
|
||||
return Affine3D(trans=self.trans, rot=self.rot.as_quat(normalize))
|
||||
|
||||
def compose(self, other: "Affine3D", autoconvert: bool = False):
|
||||
rot = self.rot
|
||||
new_rot = (rot.convert_compose if autoconvert else rot.compose)(other.rot)
|
||||
@@ -271,6 +394,13 @@ class Affine3D:
|
||||
# Assume tensor 4x4 for backward compat with alphafold
|
||||
trans = t[..., :3, 3]
|
||||
rot = RotationMatrix(t[..., :3, :3])
|
||||
case 6:
|
||||
# Assume quaternion representation with real part = 1
|
||||
trans = t[..., -3:]
|
||||
rot = RotationQuat(F.pad(t[..., :3], (1, 0), value=1))
|
||||
case 7:
|
||||
trans = t[..., -3:]
|
||||
rot = RotationQuat(t[..., :4])
|
||||
case 12:
|
||||
trans = t[..., -3:]
|
||||
rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3)))
|
||||
@@ -305,6 +435,62 @@ class Affine3D:
|
||||
return Affine3D.from_tensor(torch.cat([x.tensor for x in affines], dim=dim))
|
||||
|
||||
|
||||
def _quat_mult(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Multiply two quaternions.
|
||||
Usual torch rules for broadcasting apply.
|
||||
|
||||
Args:
|
||||
a: Quaternions as tensor of shape (..., 4), real part first.
|
||||
b: Quaternions as tensor of shape (..., 4), real part first.
|
||||
|
||||
Returns:
|
||||
The product of a and b, a tensor of quaternions shape (..., 4).
|
||||
"""
|
||||
aw, ax, ay, az = torch.unbind(a, -1)
|
||||
bw, bx, by, bz = torch.unbind(b, -1)
|
||||
ow = aw * bw - ax * bx - ay * by - az * bz
|
||||
ox = aw * bx + ax * bw + ay * bz - az * by
|
||||
oy = aw * by - ax * bz + ay * bw + az * bx
|
||||
oz = aw * bz + ax * by - ay * bx + az * bw
|
||||
return torch.stack((ow, ox, oy, oz), -1)
|
||||
|
||||
|
||||
def _quat_rotation(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rotates p by quaternion q. Usual torch rules for broadcasting apply.
|
||||
|
||||
Args:
|
||||
q: Quaternions as tensor of shape (..., 4), real part first.
|
||||
p: Points as tensor of shape (..., 3)
|
||||
|
||||
Returns:
|
||||
The rotated version of p, of shape (..., 3)
|
||||
"""
|
||||
aw, ax, ay, az = torch.unbind(q, -1)
|
||||
bx, by, bz = torch.unbind(p, -1)
|
||||
# fmt: off
|
||||
ow = - ax * bx - ay * by - az * bz
|
||||
ox = aw * bx + ay * bz - az * by
|
||||
oy = aw * by - ax * bz + az * bx
|
||||
oz = aw * bz + ax * by - ay * bx
|
||||
# fmt: on
|
||||
q_mul_pts = torch.stack((ow, ox, oy, oz), -1)
|
||||
return _quat_mult(q_mul_pts, _quat_invert(q))[..., 1:]
|
||||
|
||||
|
||||
def _quat_invert(q: torch.Tensor):
|
||||
return q * torch.tensor([1, -1, -1, -1], device=q.device)
|
||||
|
||||
|
||||
def _sqrt_subgradient(x: torch.Tensor) -> torch.Tensor:
|
||||
# Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0.
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
return ret
|
||||
|
||||
|
||||
def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12):
|
||||
# A low eps here is necessary for good stability!
|
||||
with fp32_autocast_context(x_axis.device.type):
|
||||
|
||||
@@ -6,17 +6,21 @@ from typing import Any, ClassVar, Protocol, TypeVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from esm.utils.structure.protein_structure import (
|
||||
compute_affine_and_rmsd,
|
||||
)
|
||||
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
|
||||
|
||||
|
||||
class Alignable(Protocol):
|
||||
atom37_positions: np.ndarray
|
||||
atom37_mask: np.ndarray
|
||||
# Trick to detect whether an object is a dataclass
|
||||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
||||
|
||||
@property
|
||||
def atom37_positions(self) -> np.ndarray: # type: ignore
|
||||
pass
|
||||
|
||||
@property
|
||||
def atom37_mask(self) -> np.ndarray: # type: ignore
|
||||
pass
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
|
||||
|
||||
15
esm/utils/structure/atom_indexer.py
Normal file
15
esm/utils/structure/atom_indexer.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import numpy as np
|
||||
|
||||
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||
|
||||
|
||||
class AtomIndexer:
|
||||
def __init__(self, structure, property: str, dim: int):
|
||||
self.structure = structure
|
||||
self.property = property
|
||||
self.dim = dim
|
||||
|
||||
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
|
||||
return index_by_atom_name(
|
||||
getattr(self.structure, self.property), atom_names, self.dim
|
||||
)
|
||||
@@ -1,19 +1,148 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from torch.amp import autocast # type: ignore
|
||||
|
||||
from esm.utils import residue_constants
|
||||
from esm.utils.misc import unbinpack
|
||||
from esm.utils.misc import binpack, unbinpack
|
||||
from esm.utils.structure.protein_structure import (
|
||||
compute_alignment_tensors,
|
||||
compute_gdt_ts_no_alignment,
|
||||
compute_rmsd_no_alignment,
|
||||
)
|
||||
|
||||
|
||||
def contact_precision(
|
||||
predictions: Tensor,
|
||||
targets: Tensor,
|
||||
src_lengths: Tensor | None = None,
|
||||
minsep: int = 6,
|
||||
maxsep: int | None = None,
|
||||
override_length: int | None = None, # for casp
|
||||
):
|
||||
"""Computes contact precisions.
|
||||
|
||||
For protein contact prediction, precision is measured for the top (L/K) highest confidence predictions,
|
||||
with L being the length of the protein sequence and K generally being equal to 1 or 5.
|
||||
|
||||
K = 5 measures the predictions of the very highest confidence contacts, while K = 1 is a more general measure
|
||||
over all relatively high confidence predictions.
|
||||
|
||||
Since there are roughly ~L true contacts in a protein, this is a reasonable cutoff.
|
||||
|
||||
|
||||
Args:
|
||||
predictions (Tensor): Tensor of probabilities of size (B, L, L)
|
||||
targets (Tensor): Tensor of true contacts of size (B, L, L)
|
||||
src_lengths (Tensor, optional): Lengths of each sample in the batch, if using variable lengths.
|
||||
If not provided, inferred from the size of the predictions.
|
||||
minsep (int): Minimum separation distance to consider. We often want to measure contacts at a
|
||||
certain range. Typical ranges are short [6, 12), medium [12, 24), and long [24, inf).
|
||||
maxsep (int, optional): Used in conjunction with minsep to specify a contact range. If not provided uses
|
||||
assumes no maximum range
|
||||
override_length (int, optional): Used for casp evaluation where sometimes the "true" length is not
|
||||
the same as the length of the input. Kept for posterity, we probably don't need this argument.
|
||||
"""
|
||||
if predictions.dim() == 2:
|
||||
predictions = predictions.unsqueeze(0)
|
||||
if targets.dim() == 2:
|
||||
targets = targets.unsqueeze(0)
|
||||
|
||||
# Check sizes
|
||||
if predictions.size() != targets.size():
|
||||
raise ValueError(
|
||||
f"Size mismatch. Received predictions of size {predictions.size()}, "
|
||||
f"targets of size {targets.size()}"
|
||||
)
|
||||
device = predictions.device
|
||||
|
||||
batch_size, seqlen, _ = predictions.size()
|
||||
|
||||
# Step 1) Construct a mask of size [B, L, L] to mask invalid contacts
|
||||
seqlen_range = torch.arange(seqlen, device=device)
|
||||
sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
|
||||
sep = sep.unsqueeze(0)
|
||||
# Mask contacts that are closer than minsep
|
||||
valid_mask = sep >= minsep
|
||||
# Mask contacts where target is negative (padding or unknown)
|
||||
valid_mask = valid_mask & (targets >= 0) # negative targets are invalid
|
||||
|
||||
# Mask contacts that are farther than maxsep, if provided
|
||||
if maxsep is not None:
|
||||
valid_mask &= sep < maxsep
|
||||
|
||||
if src_lengths is not None:
|
||||
# If the lengths of the individual sequences are provided, mask positions
|
||||
# that are farther than the end of the sequence.
|
||||
valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
|
||||
valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
|
||||
else:
|
||||
src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long)
|
||||
|
||||
# Fill in the logit tensor with -inf for all invalid positions
|
||||
predictions = predictions.masked_fill(~valid_mask, float("-inf"))
|
||||
|
||||
# Step 2) Select the top half of the prediction (should be symmetric)
|
||||
x_ind, y_ind = np.triu_indices(seqlen, minsep)
|
||||
predictions_upper = predictions[:, x_ind, y_ind]
|
||||
targets_upper = targets[:, x_ind, y_ind]
|
||||
|
||||
# Step 3) Select the topk values in each batch where k = L (length of sequence)
|
||||
topk = seqlen if override_length is None else max(seqlen, override_length)
|
||||
# Indices are the indices into the predictions corresponding to the most confident predictions
|
||||
indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
|
||||
# topk_targets are the target values corresponding to the above indices
|
||||
topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices]
|
||||
if topk_targets.size(1) < topk:
|
||||
# If there aren't enough targets, pad to the output.
|
||||
topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])
|
||||
|
||||
# Step 4) Sum the accuracy at of the top-i predictions for i in 1, L
|
||||
# topk_targets => 1/0 true vs. false contact, sorted by confidence of prediction
|
||||
# cmumulative sum => Number of correct answers for the top-i predictions.
|
||||
cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)
|
||||
|
||||
# Step 5) Find the gather indices. This should be P@(L / K) for varous values of K
|
||||
# The values will differ for each batch.
|
||||
gather_lengths = src_lengths.unsqueeze(1)
|
||||
if override_length is not None:
|
||||
gather_lengths = override_length * torch.ones_like(
|
||||
gather_lengths, device=device
|
||||
)
|
||||
|
||||
# This gets you (0.1 * L, 0.2 * L, 0.3 * L, etc.)
|
||||
gather_indices = (
|
||||
(torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths).type(
|
||||
torch.long
|
||||
)
|
||||
- 1
|
||||
).clamp_min(0)
|
||||
|
||||
# Step 6) Gather the results and divide by the number of guesses to get the precision.
|
||||
binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
|
||||
binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
|
||||
binned_cumulative_dist
|
||||
)
|
||||
|
||||
# Select specific P@L/k. pl5 is index 1 b/c that corresponds to L * 0.2 in
|
||||
# gather_indices above
|
||||
pl5 = binned_precisions[:, 1]
|
||||
# pl2 = binned_precisions[:, 4]
|
||||
pl = binned_precisions[:, 9]
|
||||
# AUC is the integral wrt K of P@L/K for K in range(1, L)
|
||||
auc = binned_precisions.mean(-1)
|
||||
|
||||
return {"AUC": auc, "P@L": pl, "P@L5": pl5}
|
||||
|
||||
|
||||
def compute_lddt(
|
||||
all_atom_pred_pos: torch.Tensor,
|
||||
all_atom_positions: torch.Tensor,
|
||||
all_atom_mask: torch.Tensor,
|
||||
cutoff: float = 15.0,
|
||||
pairwise_all_atom_mask: torch.Tensor | None = None,
|
||||
cutoff: float | torch.Tensor = 15.0,
|
||||
eps: float = 1e-10,
|
||||
per_residue: bool = True,
|
||||
sequence_id: torch.Tensor | None = None,
|
||||
@@ -29,7 +158,8 @@ def compute_lddt(
|
||||
all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
|
||||
all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
|
||||
all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
|
||||
cutoff (float): Max distance to score lddt over.
|
||||
pairwise_all_atom_mask (Tensor[float], [B x (L * Natoms x L * Natoms)], optional): Tensor of masks, indicating whether a pair of atoms should be considered in the LDDT calculation.
|
||||
cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
|
||||
per_residue (bool): Whether to return per-residue or full-protein lddt.
|
||||
sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed!
|
||||
|
||||
@@ -40,7 +170,7 @@ def compute_lddt(
|
||||
else:
|
||||
Tensor[float], [(Nstates x) B]
|
||||
"""
|
||||
n = all_atom_mask.shape[-2]
|
||||
all_atom_mask = all_atom_mask[..., None] # add a dimension for broadcasting
|
||||
dmat_true = torch.sqrt(
|
||||
eps
|
||||
+ torch.sum(
|
||||
@@ -57,22 +187,56 @@ def compute_lddt(
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
dists_to_score = (
|
||||
(dmat_true < cutoff)
|
||||
* all_atom_mask
|
||||
* rearrange(all_atom_mask, "... a b -> ... b a")
|
||||
* (1.0 - torch.eye(n, device=all_atom_mask.device))
|
||||
)
|
||||
mask = all_atom_mask * rearrange(all_atom_mask, "... a b -> ... b a")
|
||||
if pairwise_all_atom_mask is not None:
|
||||
mask = mask * pairwise_all_atom_mask
|
||||
|
||||
if sequence_id is not None:
|
||||
# TODO(roshan): This will work for lddt_ca, but not for regular lddt
|
||||
# TODO: This will work for lddt_ca, but not for regular lddt
|
||||
# Problem is that regular lddt has natoms * nres scores, so would need to repeat this mask by natoms
|
||||
# Leaving for now because it won't fail silently so should be ook.
|
||||
seqid_mask = sequence_id[..., None] == sequence_id[..., None, :]
|
||||
dists_to_score = dists_to_score * seqid_mask.type_as(dists_to_score)
|
||||
mask = mask * seqid_mask.type_as(mask)
|
||||
|
||||
return compute_lddt_from_dmat(
|
||||
dmat_pred, dmat_true, mask, cutoff=cutoff, eps=eps, per_residue=per_residue
|
||||
)
|
||||
|
||||
|
||||
def compute_lddt_from_dmat(
|
||||
dmat_pred: torch.Tensor,
|
||||
dmat_true: torch.Tensor,
|
||||
pairwise_mask: torch.Tensor,
|
||||
cutoff: float | torch.Tensor = 15.0,
|
||||
eps: float = 1e-10,
|
||||
per_residue: bool = True,
|
||||
):
|
||||
"""
|
||||
Compute LDDT from pre-computed distance matrices.
|
||||
This is useful when you want to compute LDDT with multiple different masks or cutoffs, e.g. for different molecule types (protein, nucleic acid, etc.).
|
||||
|
||||
Args:
|
||||
dmat_pred (Tensor[float], [B x L x L]): Predicted distance matrix
|
||||
dmat_true (Tensor[float], [B x L x L]): True distance matrix
|
||||
pairwise_mask (Tensor[float], [B x L x L]): Pairwise mask indicating which pairs of atoms to consider
|
||||
cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
|
||||
per_residue (bool): Whether to return per-residue or full-protein lddt.
|
||||
|
||||
Returns:
|
||||
LDDT Tensor:
|
||||
if per_residue:
|
||||
Tensor[float], [B x L]
|
||||
else:
|
||||
Tensor[float], [B]
|
||||
"""
|
||||
n = dmat_true.size(-1)
|
||||
dists_to_score = (
|
||||
(dmat_true < cutoff)
|
||||
* pairwise_mask
|
||||
* (1.0 - torch.eye(n, device=dmat_true.device))
|
||||
)
|
||||
|
||||
dist_l1 = torch.abs(dmat_true - dmat_pred)
|
||||
|
||||
score = (
|
||||
(dist_l1 < 0.5).type(dist_l1.dtype)
|
||||
+ (dist_l1 < 1.0).type(dist_l1.dtype)
|
||||
@@ -84,7 +248,6 @@ def compute_lddt(
|
||||
dims = (-1,) if per_residue else (-2, -1)
|
||||
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
|
||||
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
@@ -114,6 +277,56 @@ def compute_lddt_ca(
|
||||
)
|
||||
|
||||
|
||||
# NOTE(roshan): no_grad required for stack_variable_length_tensors apparently... let's revisit if we want to backprop
|
||||
@torch.no_grad()
|
||||
@autocast("cuda", enabled=False)
|
||||
def compute_rmsd(
|
||||
mobile: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
atom_exists_mask: torch.Tensor | None = None,
|
||||
sequence_id: torch.Tensor | None = None,
|
||||
reduction: str = "batch",
|
||||
):
|
||||
"""
|
||||
Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.
|
||||
|
||||
Args:
|
||||
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
||||
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
||||
- atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N)
|
||||
- sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
|
||||
- reduction (str): One of "batch", "per_sample", "per_residue".
|
||||
|
||||
Returns:
|
||||
If reduction == "batch":
|
||||
(torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch
|
||||
If reduction == "per_sample":
|
||||
(torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each batch
|
||||
If reduction == "per_residue":
|
||||
(torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for residue in the batch
|
||||
"""
|
||||
|
||||
(centered_mobile, _, centered_target, _, rotation_matrix, num_valid_atoms) = (
|
||||
compute_alignment_tensors(
|
||||
mobile=mobile,
|
||||
target=target,
|
||||
atom_exists_mask=atom_exists_mask,
|
||||
sequence_id=sequence_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply transformation to centered structure
|
||||
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
|
||||
|
||||
# Compute rmsd for centered structures
|
||||
rmsd = compute_rmsd_no_alignment(
|
||||
rotated_mobile, centered_target, num_valid_atoms, reduction=reduction
|
||||
)
|
||||
if reduction == "per_residue" and sequence_id is not None:
|
||||
rmsd = binpack(rmsd, sequence_id, pad_value=0)
|
||||
return rmsd
|
||||
|
||||
|
||||
def compute_gdt_ts(
|
||||
mobile: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
||||
469
esm/utils/structure/mmcif_parsing.py
Normal file
469
esm/utils/structure/mmcif_parsing.py
Normal file
@@ -0,0 +1,469 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
|
||||
import biotite.structure as bs
|
||||
import biotite.structure.io.pdbx as pdbx
|
||||
|
||||
from esm.utils import residue_constants
|
||||
|
||||
# Define PathOrBuffer for the opensource version
|
||||
PathOrBuffer = Union[str, os.PathLike, io.StringIO]
|
||||
|
||||
|
||||
class NoProteinError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Residue:
|
||||
residue_number: int | None = None
|
||||
insertion_code: str = ""
|
||||
hetflag: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MmcifHeader:
|
||||
release_date: datetime | None = None
|
||||
resolution: float | None = None
|
||||
structure_method: str = "UNKNOWN"
|
||||
|
||||
|
||||
class MmcifWrapper:
|
||||
def __init__(self, id: str | None = None):
|
||||
self.id: str = id or ""
|
||||
self.raw: pdbx.CIFFile | None = None
|
||||
self.structure: bs.AtomArray
|
||||
self.header: MmcifHeader = MmcifHeader()
|
||||
self.entities: dict[int, list[str]] = {}
|
||||
self.chain_to_seqres: dict[str, str] = {}
|
||||
self.seqres_to_structure: dict[str, dict[int, Residue]] = {}
|
||||
|
||||
@classmethod
|
||||
def read(cls, path: PathOrBuffer, id: str | None = None) -> MmcifWrapper:
|
||||
obj = cls(id=id)
|
||||
obj._load(path)
|
||||
return obj
|
||||
|
||||
def _load(self, path: PathOrBuffer, fileid: str | None = None):
|
||||
"""Load mmCIF data from file."""
|
||||
self.raw = pdbx.CIFFile.read(path)
|
||||
|
||||
self._parse_structure()
|
||||
self._parse_header()
|
||||
self._parse_entities()
|
||||
self._parse_sequences()
|
||||
|
||||
def _parse_structure(self):
|
||||
"""Parse the atomic structure from mmCIF."""
|
||||
try:
|
||||
structure = pdbx.get_structure(self.raw, model=1)
|
||||
if structure is None or not isinstance(structure, bs.AtomArray):
|
||||
raise NoProteinError("No structure found in mmCIF file")
|
||||
if len(structure) == 0:
|
||||
raise NoProteinError("Empty structure in mmCIF file")
|
||||
self.structure = structure
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse structure: {e}")
|
||||
|
||||
def _parse_header(self):
|
||||
"""Parse header information from mmCIF."""
|
||||
if not self.raw:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get the first (and usually only) block
|
||||
block = self.raw.block
|
||||
|
||||
# Parse release date
|
||||
if "pdbx_database_status" in block:
|
||||
status_cat = block["pdbx_database_status"]
|
||||
if "recvd_initial_deposition_date" in status_cat:
|
||||
date_str = status_cat["recvd_initial_deposition_date"].as_item()
|
||||
if date_str and date_str != "?":
|
||||
try:
|
||||
self.header.release_date = datetime.strptime(
|
||||
date_str, "%Y-%m-%d"
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Parse resolution
|
||||
if "refine" in block:
|
||||
refine_cat = block["refine"]
|
||||
if "ls_d_res_high" in refine_cat:
|
||||
res_str = refine_cat["ls_d_res_high"].as_item()
|
||||
if res_str and res_str != "?":
|
||||
try:
|
||||
self.header.resolution = float(res_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Parse structure method
|
||||
if "exptl" in block:
|
||||
exptl_cat = block["exptl"]
|
||||
if "method" in exptl_cat:
|
||||
method = exptl_cat["method"].as_item()
|
||||
if method and method != "?":
|
||||
self.header.structure_method = method.upper()
|
||||
|
||||
except Exception:
|
||||
# If parsing fails, keep default values
|
||||
pass
|
||||
|
||||
def _parse_entities(self):
|
||||
"""Parse entity information and map to chains."""
|
||||
if not self.raw:
|
||||
return
|
||||
|
||||
try:
|
||||
block = self.raw.block
|
||||
|
||||
# Parse entity information
|
||||
if "entity" in block:
|
||||
entity_cat = block["entity"]
|
||||
entity_ids = entity_cat["id"].as_array(str)
|
||||
entity_types = entity_cat["type"].as_array(str)
|
||||
|
||||
# Initialize entities dict with all entities (not just polymers)
|
||||
for i, (entity_id, entity_type) in enumerate(
|
||||
zip(entity_ids, entity_types)
|
||||
):
|
||||
self.entities[int(entity_id)] = []
|
||||
|
||||
# Map polymer chains to entities using entity_poly
|
||||
if "entity_poly" in block:
|
||||
poly_cat = block["entity_poly"]
|
||||
entity_ids = poly_cat["entity_id"].as_array(str)
|
||||
chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
|
||||
|
||||
for entity_id, chain_list in zip(entity_ids, chain_lists):
|
||||
entity_id = int(entity_id)
|
||||
# Chain list is comma-separated
|
||||
chains = [c.strip() for c in chain_list.split(",") if c.strip()]
|
||||
if entity_id in self.entities:
|
||||
self.entities[entity_id] = chains
|
||||
|
||||
# Map non-polymer chains using struct_asym for entities not covered by entity_poly
|
||||
if "struct_asym" in block:
|
||||
asym_cat = block["struct_asym"]
|
||||
asym_ids = asym_cat["id"].as_array(str)
|
||||
entity_ids = asym_cat["entity_id"].as_array(str)
|
||||
|
||||
for asym_id, entity_id in zip(asym_ids, entity_ids):
|
||||
entity_id = int(entity_id)
|
||||
# Only add if entity exists but has no chains yet (non-polymer entities)
|
||||
if entity_id in self.entities and not self.entities[entity_id]:
|
||||
self.entities[entity_id].append(asym_id)
|
||||
|
||||
except Exception:
|
||||
# If parsing fails, try to infer from structure
|
||||
if (
|
||||
self.structure
|
||||
and hasattr(self.structure, "chain_id")
|
||||
and self.structure.chain_id is not None
|
||||
and hasattr(self.structure.chain_id, "__iter__")
|
||||
):
|
||||
chain_ids = list(set(self.structure.chain_id))
|
||||
self.entities = {1: chain_ids}
|
||||
|
||||
def _parse_sequences(self):
|
||||
"""Parse sequence information from mmCIF."""
|
||||
if not self.raw:
|
||||
return
|
||||
|
||||
block = self.raw.block
|
||||
|
||||
# Parse polymer sequences
|
||||
if "entity_poly" in block:
|
||||
poly_cat = block["entity_poly"]
|
||||
entity_ids = poly_cat["entity_id"].as_array(str)
|
||||
sequences = poly_cat["pdbx_seq_one_letter_code_can"].as_array(str)
|
||||
chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
|
||||
|
||||
for entity_id, sequence, chain_list in zip(
|
||||
entity_ids, sequences, chain_lists
|
||||
):
|
||||
# Clean up sequence (remove whitespace and newlines)
|
||||
clean_seq = "".join(sequence.split())
|
||||
chains = [c.strip() for c in chain_list.split(",") if c.strip()]
|
||||
|
||||
for chain_id in chains:
|
||||
self.chain_to_seqres[chain_id] = clean_seq
|
||||
|
||||
# Parse sequence to structure mapping
|
||||
if "pdbx_poly_seq_scheme" in block:
|
||||
seq_cat = block["pdbx_poly_seq_scheme"]
|
||||
asym_ids = seq_cat["asym_id"].as_array(str) # Internal chain IDs
|
||||
seq_positions = seq_cat["seq_id"].as_array(str)
|
||||
auth_seq_nums = seq_cat["auth_seq_num"].as_array(str)
|
||||
ins_codes = (
|
||||
seq_cat["pdb_ins_code"].as_array(str)
|
||||
if "pdb_ins_code" in seq_cat
|
||||
else [""] * len(asym_ids)
|
||||
)
|
||||
hetflags = (
|
||||
seq_cat["hetflag"].as_array(str)
|
||||
if "hetflag" in seq_cat
|
||||
else ["N"] * len(asym_ids)
|
||||
)
|
||||
|
||||
# Get author chain IDs if available
|
||||
auth_chain_ids = (
|
||||
seq_cat["pdb_strand_id"].as_array(str)
|
||||
if "pdb_strand_id" in seq_cat
|
||||
else asym_ids # Fallback to internal IDs
|
||||
)
|
||||
|
||||
# Build mapping from internal chain ID to author chain ID
|
||||
asym_to_auth_mapping = {}
|
||||
for asym_id, auth_id in zip(asym_ids, auth_chain_ids):
|
||||
asym_to_auth_mapping[asym_id] = auth_id
|
||||
|
||||
# Group by internal chain ID first, then map to author chain ID
|
||||
chain_data = {}
|
||||
for asym_id, seq_pos, auth_seq, ins_code, hetflag in zip(
|
||||
asym_ids, seq_positions, auth_seq_nums, ins_codes, hetflags
|
||||
):
|
||||
if asym_id not in chain_data:
|
||||
chain_data[asym_id] = {}
|
||||
|
||||
try:
|
||||
seq_index = int(seq_pos) - 1 # Convert to 0-based indexing
|
||||
res_num = int(auth_seq) if auth_seq != "?" else None
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if res_num is not None:
|
||||
# Convert mmCIF "." and "?" to empty string
|
||||
clean_ins_code = "" if ins_code in [".", "?"] else ins_code
|
||||
else:
|
||||
clean_ins_code = ""
|
||||
res_num = None
|
||||
|
||||
is_het = hetflag.upper() == "Y" # type: ignore
|
||||
chain_data[asym_id][seq_index] = Residue(
|
||||
residue_number=res_num,
|
||||
insertion_code=clean_ins_code, # type: ignore
|
||||
hetflag=is_het,
|
||||
)
|
||||
|
||||
# Handle cases where multiple residues have the same auth_seq_num
|
||||
# by adjusting residue numbers to be unique within each chain
|
||||
for asym_id, residue_data in chain_data.items():
|
||||
# Check if there are duplicate residue numbers in this chain
|
||||
positions_with_same_num = {}
|
||||
for seq_idx, res_at_pos in residue_data.items():
|
||||
if res_at_pos.residue_number is not None:
|
||||
res_num = res_at_pos.residue_number
|
||||
if res_num not in positions_with_same_num:
|
||||
positions_with_same_num[res_num] = []
|
||||
positions_with_same_num[res_num].append(seq_idx)
|
||||
|
||||
# Fix duplicate residue numbers by making them sequential
|
||||
for res_num, seq_indices in positions_with_same_num.items():
|
||||
if len(seq_indices) > 1:
|
||||
# Multiple residues have the same residue number
|
||||
# Make them sequential starting from the original number
|
||||
seq_indices.sort() # Ensure consistent ordering
|
||||
for i, seq_idx in enumerate(seq_indices):
|
||||
original_pos = residue_data[seq_idx]
|
||||
new_pos = Residue(
|
||||
residue_number=res_num + i,
|
||||
insertion_code=original_pos.insertion_code,
|
||||
hetflag=original_pos.hetflag,
|
||||
)
|
||||
residue_data[seq_idx] = new_pos
|
||||
|
||||
# Create ordered mappings using author chain IDs
|
||||
for asym_id in chain_data:
|
||||
auth_chain_id = asym_to_auth_mapping.get(asym_id, asym_id)
|
||||
if auth_chain_id in self.chain_to_seqres:
|
||||
seq_len = len(self.chain_to_seqres[auth_chain_id])
|
||||
ordered_mapping = {}
|
||||
|
||||
for i in range(seq_len):
|
||||
if i in chain_data[asym_id]:
|
||||
ordered_mapping[i] = chain_data[asym_id][i]
|
||||
else:
|
||||
# Missing residue - no structure coordinates
|
||||
ordered_mapping[i] = Residue(
|
||||
residue_number=None, insertion_code="", hetflag=False
|
||||
)
|
||||
|
||||
self.seqres_to_structure[auth_chain_id] = ordered_mapping
|
||||
else:
|
||||
# Handle case where auth_chain_id is not in chain_to_seqres
|
||||
# This can happen if the chain is not a polymer or if there's a parsing issue
|
||||
# Create a basic mapping based on the chain_data
|
||||
if chain_data[asym_id]:
|
||||
# Sort by sequence index to create ordered mapping
|
||||
sorted_indices = sorted(chain_data[asym_id].keys())
|
||||
ordered_mapping = {}
|
||||
for i, seq_idx in enumerate(sorted_indices):
|
||||
ordered_mapping[i] = chain_data[asym_id][seq_idx]
|
||||
self.seqres_to_structure[auth_chain_id] = ordered_mapping
|
||||
|
||||
# Ensure all chains have complete mappings
|
||||
for chain_id in self.chain_to_seqres:
|
||||
if chain_id not in self.seqres_to_structure:
|
||||
seq_len = len(self.chain_to_seqres[chain_id])
|
||||
self.seqres_to_structure[chain_id] = {
|
||||
i: Residue(residue_number=None, insertion_code="", hetflag=False)
|
||||
for i in range(seq_len)
|
||||
}
|
||||
else:
|
||||
# Fill in any missing indices
|
||||
seq_len = len(self.chain_to_seqres[chain_id])
|
||||
mapping = self.seqres_to_structure[chain_id]
|
||||
for i in range(seq_len):
|
||||
if i not in mapping:
|
||||
mapping[i] = Residue(
|
||||
residue_number=None, insertion_code="", hetflag=False
|
||||
)
|
||||
|
||||
# Fallback: create basic mappings from structure for missing chains
|
||||
if (
|
||||
self.structure
|
||||
and hasattr(self.structure, "chain_id")
|
||||
and self.structure.chain_id is not None
|
||||
and hasattr(self.structure.chain_id, "__iter__")
|
||||
):
|
||||
for chain_id in set(self.structure.chain_id):
|
||||
if chain_id not in self.seqres_to_structure:
|
||||
chain_structure = self.structure[
|
||||
self.structure.chain_id == chain_id
|
||||
]
|
||||
if (
|
||||
hasattr(chain_structure, "res_id")
|
||||
and chain_structure.res_id is not None
|
||||
and hasattr(chain_structure.res_id, "__iter__")
|
||||
):
|
||||
residue_ids = list(set(chain_structure.res_id))
|
||||
residue_ids.sort()
|
||||
|
||||
self.seqres_to_structure[chain_id] = {
|
||||
i: Residue(
|
||||
residue_number=res_id, insertion_code="", hetflag=False
|
||||
)
|
||||
for i, res_id in enumerate(residue_ids)
|
||||
}
|
||||
|
||||
def _parse_nonpoly_from_mmcif(self) -> dict[tuple, bs.AtomArray]:
|
||||
"""Parse non-polymer coordinates from mmCIF block data."""
|
||||
nonpoly_coords = {}
|
||||
|
||||
# Get non-polymer entities from the mmCIF block
|
||||
assert self.raw is not None
|
||||
block = self.raw.block
|
||||
nonpoly_entities = set()
|
||||
|
||||
# Find non-polymer entities
|
||||
if "entity" in block:
|
||||
entity_cat = block["entity"]
|
||||
entity_ids = entity_cat["id"].as_array(str)
|
||||
entity_types = entity_cat["type"].as_array(str)
|
||||
|
||||
for entity_id, entity_type in zip(entity_ids, entity_types):
|
||||
if entity_type.upper() in ["NON-POLYMER", "WATER", "BRANCHED"]:
|
||||
nonpoly_entities.add(entity_id)
|
||||
|
||||
# Map entities to chains for non-polymers
|
||||
entity_to_chains = {}
|
||||
if "pdbx_entity_nonpoly" in block:
|
||||
nonpoly_cat = block["pdbx_entity_nonpoly"]
|
||||
entity_ids = nonpoly_cat["entity_id"].as_array(str)
|
||||
comp_ids = nonpoly_cat["comp_id"].as_array(str)
|
||||
|
||||
for entity_id, comp_id in zip(entity_ids, comp_ids):
|
||||
if entity_id in nonpoly_entities:
|
||||
entity_to_chains[entity_id] = comp_id
|
||||
|
||||
# Get atom site information for non-polymers
|
||||
if "atom_site" in block:
|
||||
atom_cat = block["atom_site"]
|
||||
atom_chain_ids = atom_cat["label_asym_id"].as_array(str)
|
||||
atom_entity_ids = atom_cat["label_entity_id"].as_array(str)
|
||||
atom_comp_ids = atom_cat["label_comp_id"].as_array(str)
|
||||
|
||||
# Group non-polymer atoms by entity and chain
|
||||
nonpoly_atom_groups = {}
|
||||
for i, (chain_id, entity_id, comp_id) in enumerate(
|
||||
zip(atom_chain_ids, atom_entity_ids, atom_comp_ids)
|
||||
):
|
||||
if entity_id in nonpoly_entities:
|
||||
key = (comp_id, chain_id)
|
||||
if key not in nonpoly_atom_groups:
|
||||
nonpoly_atom_groups[key] = []
|
||||
nonpoly_atom_groups[key].append(i)
|
||||
|
||||
# Extract coordinates for each non-polymer group
|
||||
for (comp_id, chain_id), atom_indices in nonpoly_atom_groups.items():
|
||||
# Match atoms by comparing chain_id and residue name
|
||||
structure_mask = (self.structure.chain_id == chain_id) & (
|
||||
self.structure.res_name == comp_id
|
||||
)
|
||||
|
||||
if structure_mask.any():
|
||||
nonpoly_array = self.structure[structure_mask]
|
||||
if (
|
||||
isinstance(nonpoly_array, (bs.AtomArray, bs.AtomArrayStack))
|
||||
and len(nonpoly_array) > 0
|
||||
):
|
||||
nonpoly_coords[(comp_id, chain_id)] = nonpoly_array
|
||||
|
||||
return nonpoly_coords
|
||||
|
||||
def _parse_nonpoly_fallback(self) -> dict[tuple, bs.AtomArray]:
|
||||
"""Fallback method to extract heteroatoms directly from structure."""
|
||||
nonpoly_coords = {}
|
||||
|
||||
if not (self.structure and hasattr(self.structure, "chain_id")):
|
||||
return nonpoly_coords
|
||||
|
||||
# Create set of standard residues from residue_constants
|
||||
standard_residues = set(residue_constants.resnames[:-1]) # Exclude 'UNK'
|
||||
standard_residues.update({"A", "C", "G", "T", "U"}) # Add nucleic acids
|
||||
|
||||
if hasattr(self.structure, "chain_id") and self.structure.chain_id is not None:
|
||||
for chain_id in set(self.structure.chain_id):
|
||||
chain_structure = self.structure[self.structure.chain_id == chain_id]
|
||||
|
||||
# Find non-standard residues
|
||||
if (
|
||||
hasattr(chain_structure, "res_name")
|
||||
and chain_structure.res_name is not None
|
||||
and hasattr(chain_structure.res_name, "__iter__")
|
||||
):
|
||||
for res_name in set(chain_structure.res_name):
|
||||
if res_name not in standard_residues:
|
||||
res_mask = (chain_structure.chain_id == chain_id) & (
|
||||
chain_structure.res_name == res_name
|
||||
)
|
||||
if res_mask.any() and isinstance(
|
||||
chain_structure, (bs.AtomArray, bs.AtomArrayStack)
|
||||
):
|
||||
nonpoly_array = chain_structure[res_mask]
|
||||
nonpoly_coords[(res_name, chain_id)] = nonpoly_array
|
||||
|
||||
return nonpoly_coords
|
||||
|
||||
@functools.cached_property
|
||||
def non_polymer_coords(self) -> dict[tuple, bs.AtomArray]:
|
||||
"""
|
||||
Extract non-polymer coordinates (ligands, cofactors, etc.) from mmCIF structure.
|
||||
|
||||
Returns a dictionary mapping (nonpolymer_info, chain_id) tuples to AtomArrays.
|
||||
"""
|
||||
if not self.structure or not self.raw:
|
||||
return {}
|
||||
|
||||
try:
|
||||
return self._parse_nonpoly_from_mmcif()
|
||||
except Exception:
|
||||
return self._parse_nonpoly_fallback()
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, replace
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Sequence, TypeVar, Union
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import biotite.structure as bs
|
||||
import brotli
|
||||
@@ -12,51 +13,36 @@ import msgpack
|
||||
import msgpack_numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
from Bio.Data import PDBData
|
||||
from biotite.application.dssp import DsspApp
|
||||
from biotite.database import rcsb
|
||||
from biotite.structure.io.pdb import PDBFile
|
||||
from cloudpathlib import CloudPath
|
||||
from scipy.spatial import ConvexHull
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
from torch import Tensor
|
||||
from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
|
||||
from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
|
||||
from scipy.spatial import ConvexHull, KDTree
|
||||
from scipy.spatial.distance import cdist, pdist, squareform
|
||||
|
||||
from esm.utils import residue_constants as RC
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.utils import residue_constants
|
||||
from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.metrics import compute_lddt_ca
|
||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue
|
||||
from esm.utils.structure.normalize_coordinates import (
|
||||
apply_frame_to_coords,
|
||||
get_protein_normalization_frame,
|
||||
normalize_coordinates,
|
||||
)
|
||||
from esm.utils.structure.protein_structure import index_by_atom_name
|
||||
from esm.utils.types import PathOrBuffer
|
||||
|
||||
msgpack_numpy.patch()
|
||||
|
||||
CHAIN_ID_CONST = "A"
|
||||
|
||||
|
||||
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
|
||||
PathLike = Union[str, Path, CloudPath]
|
||||
PathOrBuffer = Union[PathLike, io.StringIO]
|
||||
|
||||
|
||||
def index_by_atom_name(
|
||||
atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
|
||||
) -> ArrayOrTensor:
|
||||
squeeze = False
|
||||
if isinstance(atom_names, str):
|
||||
atom_names = [atom_names]
|
||||
squeeze = True
|
||||
indices = [RC.atom_order[atom_name] for atom_name in atom_names]
|
||||
dim = dim % atom37.ndim
|
||||
index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
|
||||
result = atom37[index] # type: ignore
|
||||
if squeeze:
|
||||
result = result.squeeze(dim)
|
||||
return result
|
||||
def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
|
||||
return sum(
|
||||
residue.residue_number is not None
|
||||
for residue in seqres_to_structure_chain.values()
|
||||
)
|
||||
|
||||
|
||||
def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143):
|
||||
@@ -78,34 +64,84 @@ def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143):
|
||||
return Ca + sum([m * d for m, d in zip(m, d)])
|
||||
|
||||
|
||||
class AtomIndexer:
|
||||
def __init__(self, structure: ProteinChain, property: str, dim: int):
|
||||
self.structure = structure
|
||||
self.property = property
|
||||
self.dim = dim
|
||||
def chain_to_ndarray(
|
||||
atom_array: bs.AtomArray, mmcif: MmcifWrapper, chain_id: str, is_predicted=False
|
||||
):
|
||||
entity_id = None
|
||||
for entity, chains in mmcif.entities.items():
|
||||
if chain_id in chains:
|
||||
entity_id = entity
|
||||
num_res = len(mmcif.chain_to_seqres[chain_id])
|
||||
sequence = mmcif.chain_to_seqres[chain_id]
|
||||
|
||||
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
|
||||
return index_by_atom_name(
|
||||
getattr(self.structure, self.property), atom_names, self.dim
|
||||
)
|
||||
atom_positions = np.full([num_res, residue_constants.atom_type_num, 3], np.nan)
|
||||
atom_mask = np.full([num_res, residue_constants.atom_type_num], False, dtype=bool)
|
||||
residue_index = np.full([num_res], -1, dtype=np.int64)
|
||||
insertion_code = np.full([num_res], "", dtype="<U4")
|
||||
|
||||
confidence = np.ones([num_res], dtype=np.float32)
|
||||
|
||||
for res_index in range(num_res):
|
||||
chain = atom_array[atom_array.chain_id == chain_id]
|
||||
assert isinstance(chain, bs.AtomArray)
|
||||
res_at_position = mmcif.seqres_to_structure[chain_id][res_index]
|
||||
|
||||
if res_at_position.residue_number is None:
|
||||
continue
|
||||
|
||||
residue_index[res_index] = res_at_position.residue_number
|
||||
insertion_code[res_index] = res_at_position.insertion_code
|
||||
res = chain[
|
||||
(chain.res_id == res_at_position.residue_number)
|
||||
& (chain.ins_code == res_at_position.insertion_code)
|
||||
& (chain.hetero == res_at_position.hetflag)
|
||||
]
|
||||
assert isinstance(res, bs.AtomArray)
|
||||
|
||||
# Atom level features
|
||||
for atom in res:
|
||||
atom_name = atom.atom_name
|
||||
if atom_name == "SE" and atom.res_name == "MSE":
|
||||
# Put the coords of the selenium atom in the sulphur column
|
||||
atom_name = "SD"
|
||||
|
||||
if atom_name in residue_constants.atom_order:
|
||||
atom_positions[res_index, residue_constants.atom_order[atom_name]] = (
|
||||
atom.coord
|
||||
)
|
||||
atom_mask[res_index, residue_constants.atom_order[atom_name]] = True
|
||||
if is_predicted and atom_name == "CA":
|
||||
confidence[res_index] = atom.b_factor
|
||||
|
||||
assert all(sequence), "Some residue name was not specified correctly"
|
||||
return (
|
||||
sequence,
|
||||
atom_positions,
|
||||
atom_mask,
|
||||
residue_index,
|
||||
insertion_code,
|
||||
confidence,
|
||||
entity_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class ProteinChain:
|
||||
"""Dataclass with atom37 representation of a single protein chain."""
|
||||
|
||||
id: str
|
||||
sequence: str
|
||||
chain_id: str # author chain id
|
||||
chain_id: str # author chain id - mutable
|
||||
entity_id: int | None
|
||||
residue_index: np.ndarray
|
||||
insertion_code: np.ndarray
|
||||
atom37_positions: np.ndarray
|
||||
atom37_mask: np.ndarray
|
||||
confidence: np.ndarray
|
||||
mmcif: MmcifWrapper | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.atom37_mask = self.atom37_mask.astype(bool)
|
||||
assert self.atom37_mask.dtype == bool, self.atom37_mask.dtype
|
||||
assert self.atom37_positions.shape[0] == len(self.sequence), (
|
||||
self.atom37_positions.shape,
|
||||
len(self.sequence),
|
||||
@@ -152,10 +188,10 @@ class ProteinChain:
|
||||
chain_id="A" if self.chain_id is None else self.chain_id,
|
||||
res_id=res_idx,
|
||||
ins_code=ins_code,
|
||||
res_name=RC.restype_1to3.get(res_name, "UNK"),
|
||||
res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
|
||||
hetero=False,
|
||||
atom_name=RC.atom_types[i],
|
||||
element=RC.atom_types[i][0],
|
||||
atom_name=residue_constants.atom_types[i],
|
||||
element=residue_constants.atom_types[i][0],
|
||||
b_factor=conf,
|
||||
)
|
||||
atoms.append(atom)
|
||||
@@ -182,18 +218,20 @@ class ProteinChain:
|
||||
# hard coded to as we currently only support single chain structures
|
||||
chain_id=CHAIN_ID_CONST,
|
||||
res_id=res_idx + 1,
|
||||
res_name=RC.restype_1to3.get(res_name, "UNK"),
|
||||
res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
|
||||
hetero=False,
|
||||
atom_name=RC.atom_types[i],
|
||||
element=RC.atom_types[i][0],
|
||||
atom_name=residue_constants.atom_types[i],
|
||||
element=residue_constants.atom_types[i][0],
|
||||
b_factor=conf,
|
||||
)
|
||||
atoms.append(atom)
|
||||
return bs.array(atoms)
|
||||
|
||||
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
|
||||
def __getitem__(self, idx: int | list[int] | slice | np.ndarray | torch.Tensor):
|
||||
if isinstance(idx, int):
|
||||
idx = [idx]
|
||||
if isinstance(idx, torch.Tensor):
|
||||
idx = idx.cpu().numpy()
|
||||
|
||||
sequence = slice_python_object_as_numpy(self.sequence, idx)
|
||||
return replace(
|
||||
@@ -216,17 +254,6 @@ class ProteinChain:
|
||||
np.fill_diagonal(contacts, -1)
|
||||
return contacts
|
||||
|
||||
def to_structure_encoder_inputs(
|
||||
self, should_normalize_coordinates: bool = True
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
coords = torch.tensor(self.atom37_positions, dtype=torch.float32)
|
||||
plddt = torch.tensor(self.confidence, dtype=torch.float32)
|
||||
residue_index = torch.tensor(self.residue_index, dtype=torch.long)
|
||||
|
||||
if should_normalize_coordinates:
|
||||
coords = normalize_coordinates(coords)
|
||||
return coords.unsqueeze(0), plddt.unsqueeze(0), residue_index.unsqueeze(0)
|
||||
|
||||
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
|
||||
"""Dssp works better w/o insertions."""
|
||||
f = PDBFile()
|
||||
@@ -242,11 +269,77 @@ class ProteinChain:
|
||||
buf.seek(0)
|
||||
return buf.read()
|
||||
|
||||
def to_mmcif(self, path: PathOrBuffer):
|
||||
f = CIFFile()
|
||||
set_structure_pdbx(f, self.atom_array, data_block=self.id)
|
||||
|
||||
# incantations molstar needs to render pLDDT / confidence onto
|
||||
# the structure with "alphafold-view"
|
||||
f.block["ma_qa_metric"] = CIFCategory(
|
||||
name="ma_qa_metric",
|
||||
columns={
|
||||
"id": CIFColumn(data=CIFData(array=np.array([1, 2]), dtype=np.int64)),
|
||||
"mode": CIFColumn(
|
||||
data=CIFData(array=np.array(["global", "local"]), dtype=np.str_)
|
||||
),
|
||||
"name": CIFColumn(
|
||||
data=CIFData(array=np.array(["pLDDT", "pLDDT"]), dtype=np.str_)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# table is a duplicate of data already in the atom array, but
|
||||
# needed by molstar to render pLDDT / confidence
|
||||
resid_pldd_table = {
|
||||
# hard coded to as we currently only support single chain structures
|
||||
"label_asym_id": CIFColumn(
|
||||
data=CIFData(
|
||||
array=[CHAIN_ID_CONST] * len(self.residue_index), dtype=np.str_
|
||||
)
|
||||
),
|
||||
"label_comp_id": CIFColumn(
|
||||
data=CIFData(
|
||||
array=[
|
||||
residue_constants.restype_1to3.get(c, "UNK")
|
||||
for c in self.sequence
|
||||
],
|
||||
dtype=np.str_,
|
||||
)
|
||||
),
|
||||
"label_seq_id": CIFColumn(
|
||||
data=CIFData(array=self.residue_index, dtype=np.int64)
|
||||
),
|
||||
"ordinal_id": CIFColumn(
|
||||
data=CIFData(array=self.residue_index, dtype=np.int64)
|
||||
),
|
||||
# hard coded to show these are all local plDDT values
|
||||
"metric_id": CIFColumn(
|
||||
data=CIFData(array=["2"] * len(self.residue_index), dtype=np.str_)
|
||||
),
|
||||
"metric_value": CIFColumn(
|
||||
data=CIFData(array=self.confidence, dtype=np.float32)
|
||||
),
|
||||
# hard coded to show there are the initial version, there are no revisions
|
||||
"model_id": CIFColumn(
|
||||
data=CIFData(array=["1"] * len(self.residue_index), dtype=np.str_)
|
||||
),
|
||||
}
|
||||
f.block["ma_qa_metric_local"] = CIFCategory(
|
||||
name="ma_qa_metric_local", columns=resid_pldd_table
|
||||
)
|
||||
f.write(path)
|
||||
|
||||
def to_mmcif_string(self) -> str:
|
||||
buf = io.StringIO()
|
||||
self.to_mmcif(buf)
|
||||
buf.seek(0)
|
||||
return buf.read()
|
||||
|
||||
def state_dict(self, backbone_only=False, json_serializable=False):
|
||||
"""This state dict is optimized for storage, so it turns things to fp16 whenever
|
||||
possible. Note that we also only support int32 residue indices, I'm hoping we don't
|
||||
need more than 2**32 residues..."""
|
||||
dct = {k: v for k, v in asdict(self).items()}
|
||||
dct = {k: v for k, v in asdict(self).items() if k not in ["mmcif"]}
|
||||
if backbone_only:
|
||||
dct["atom37_mask"][:, 3:] = False
|
||||
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
|
||||
@@ -265,7 +358,11 @@ class ProteinChain:
|
||||
return dct
|
||||
|
||||
def to_blob(self, backbone_only=False) -> bytes:
|
||||
return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)))
|
||||
return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)), quality=5)
|
||||
|
||||
@classmethod
|
||||
def from_open_source(cls, pc: ProteinChain):
|
||||
return cls(**vars(pc))
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, dct):
|
||||
@@ -280,11 +377,11 @@ class ProteinChain:
|
||||
k: (v.astype(np.float32) if k in ["atom37_positions", "confidence"] else v)
|
||||
for k, v in dct.items()
|
||||
}
|
||||
return cls(**dct)
|
||||
return cls(**dct, mmcif=None)
|
||||
|
||||
@classmethod
|
||||
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
|
||||
"""NOTE: blob + sparse coding + brotli + fp16 reduces memory
|
||||
"""NOTE(@zlin): blob + sparse coding + brotli + fp16 reduces memory
|
||||
of chains from 52G/1M chains to 20G/1M chains, I think this is a good first
|
||||
shot at compressing and dumping chains to disk. I'm sure there's better ways."""
|
||||
match input:
|
||||
@@ -296,24 +393,105 @@ class ProteinChain:
|
||||
bytes = input
|
||||
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
|
||||
|
||||
def dssp(self):
|
||||
dssp = DsspApp.annotate_sse(self.atom_array_no_insertions)
|
||||
full_dssp = np.full(len(self.sequence), "X", dtype="<U1")
|
||||
full_dssp[self.atom37_mask.any(-1)] = dssp
|
||||
return full_dssp
|
||||
|
||||
def sasa(self):
|
||||
def sasa(self, by_residue: bool = True):
|
||||
arr = self.atom_array_no_insertions
|
||||
sasa_per_atom = bs.sasa(arr) # type: ignore
|
||||
# Sum per-atom SASA into residue "bins", with np.bincount.
|
||||
if by_residue:
|
||||
# Sum per-atom SASA into residue "bins", with np.bincount.
|
||||
assert arr.res_id is not None
|
||||
# NOTE(rverkuil): arr.res_id is 1-indexed, but np.bincount returns a sum for bin 0, so we strip.
|
||||
# NOTE(aderry): We compute only for residues with coordinates, return NaN otherwise.
|
||||
num_trailing_residues = len(self) - arr.res_id.max()
|
||||
sasa_per_residue = np.concatenate(
|
||||
[
|
||||
np.bincount(arr.res_id, weights=sasa_per_atom)[1:],
|
||||
np.zeros(num_trailing_residues),
|
||||
]
|
||||
)
|
||||
sasa_per_residue[~self.atom37_mask.any(-1)] = np.nan
|
||||
assert len(sasa_per_residue) == len(self)
|
||||
return sasa_per_residue
|
||||
return sasa_per_atom
|
||||
|
||||
def sap_score(self, aggregation: str = "atom") -> np.ndarray:
|
||||
"""Computes per-atom SAP score.
|
||||
Can optionally aggregate by residue (by averaging over atoms. NOTE: this returns values only for residues that have coordinates!)
|
||||
or full-protein (sum of SAP score for atoms with SAP > 0, as in Lauer et al. 2011)."""
|
||||
sap_radius = 5.0
|
||||
arr = self.atom_array_no_insertions
|
||||
|
||||
# asserts to avoid type errors
|
||||
assert arr.res_id is not None
|
||||
assert np.array_equal(
|
||||
np.sort(np.unique(arr.res_id)), np.arange(1, arr.res_id.max() + 1)
|
||||
), "SASA calculation expected contiguous res_ids in range(1, len(chain)+1)"
|
||||
# NOTE: arr.res_id is 1-indexed, but np.bincount returns a sum for bin 0, so we strip.
|
||||
sasa_per_residue = np.bincount(arr.res_id, weights=sasa_per_atom)[1:]
|
||||
assert len(sasa_per_residue) == len(self)
|
||||
return sasa_per_residue
|
||||
assert arr.res_name is not None
|
||||
assert arr.atom_name is not None
|
||||
assert arr.coord is not None
|
||||
|
||||
# compute SASA and residue-specific properties
|
||||
sasa_per_atom = self.sasa(by_residue=False)
|
||||
resid_to_resname = dict(zip(arr.res_id, arr.res_name))
|
||||
|
||||
max_side_chain_asa = np.full(len(self), np.nan)
|
||||
res_hydrophobicity = np.full(len(self), np.nan)
|
||||
resolved_res_mask = self.atom37_mask.any(-1)
|
||||
num_trailing_residues = len(self) - arr.res_id.max()
|
||||
|
||||
max_side_chain_asa[resolved_res_mask] = np.array(
|
||||
[
|
||||
residue_constants.side_chain_asa[resid_to_resname[i]]
|
||||
for i in np.unique(arr.res_id)
|
||||
]
|
||||
)
|
||||
res_hydrophobicity[resolved_res_mask] = np.array(
|
||||
[
|
||||
residue_constants.hydrophobicity[resid_to_resname[i]]
|
||||
for i in np.unique(arr.res_id)
|
||||
]
|
||||
)
|
||||
assert len(max_side_chain_asa) == len(self)
|
||||
assert len(res_hydrophobicity) == len(self)
|
||||
|
||||
# compute SAP score
|
||||
is_side_chain = ~bs.filter_peptide_backbone(arr)
|
||||
sasa_per_atom[is_side_chain] = 0
|
||||
kdtree = KDTree(arr.coord)
|
||||
neighbors = kdtree.query_ball_tree(kdtree, sap_radius, p=2.0)
|
||||
sap_by_atom = np.zeros_like(sasa_per_atom)
|
||||
for i, nn_list in enumerate(neighbors):
|
||||
saa_nn = np.zeros_like(sasa_per_atom)
|
||||
saa_nn[nn_list] = sasa_per_atom[nn_list]
|
||||
sasa_within_r = np.concatenate(
|
||||
[
|
||||
np.bincount(arr.res_id, weights=saa_nn)[1:],
|
||||
np.zeros(num_trailing_residues),
|
||||
]
|
||||
)
|
||||
sap = np.nansum((sasa_within_r / max_side_chain_asa) * res_hydrophobicity)
|
||||
sap_by_atom[i] = sap
|
||||
|
||||
match aggregation:
|
||||
case "atom":
|
||||
return sap_by_atom
|
||||
case "residue":
|
||||
sap_by_residue = np.concatenate(
|
||||
[
|
||||
np.bincount(arr.res_id, weights=sap_by_atom)[1:],
|
||||
np.zeros(num_trailing_residues),
|
||||
]
|
||||
) / (
|
||||
np.concatenate(
|
||||
[np.bincount(arr.res_id)[1:], np.zeros(num_trailing_residues)]
|
||||
)
|
||||
+ 1e-8
|
||||
)
|
||||
sap_by_residue[~resolved_res_mask] = np.nan
|
||||
assert len(sap_by_residue) == len(self)
|
||||
return sap_by_residue
|
||||
case "protein":
|
||||
return sum(sap_by_atom[sap_by_atom > 0]) # pyright: ignore[reportReturnType]
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Invalid aggregation method: {aggregation}. Must be one of 'atom', 'residue', or 'protein'"
|
||||
)
|
||||
|
||||
def globularity(self) -> float:
|
||||
# Computes globularity using total volumes divided by MVEE.
|
||||
@@ -325,10 +503,10 @@ class ProteinChain:
|
||||
# NOTE(@zeming): due to the approximation we make here, that atoms never overlap, you might get >1 globularity
|
||||
mask = self.atom37_mask.any(-1)
|
||||
points = self.atom37_positions[self.atom37_mask]
|
||||
sequence = [aa for aa, m in zip(self.sequence, mask) if m]
|
||||
sequence = [aa for aa, m in zip(self.sequence, mask) if m] # type: ignore
|
||||
A, _ = self._mvee(points, tol=1e-3)
|
||||
mvee_volume = (4 * np.pi) / (3 * np.sqrt(np.linalg.det(A)))
|
||||
volume = sum(RC.amino_acid_volumes[x] for x in sequence)
|
||||
volume = sum(residue_constants.amino_acid_volumes[x] for x in sequence)
|
||||
ratio = volume / mvee_volume
|
||||
|
||||
# The paper says you must compare the ellipsoidal profile with T, a measurement of
|
||||
@@ -388,6 +566,10 @@ class ProteinChain:
|
||||
|
||||
return A, c
|
||||
|
||||
def radius_of_gyration(self):
|
||||
arr = self.atom_array_no_insertions
|
||||
return bs.gyration_radius(arr)
|
||||
|
||||
def align(
|
||||
self,
|
||||
target: ProteinChain,
|
||||
@@ -462,8 +644,8 @@ class ProteinChain:
|
||||
target_inds: list[int] | np.ndarray | None = None,
|
||||
**kwargs,
|
||||
) -> float | np.ndarray:
|
||||
"""Compute the LDDT between this protein chain and another.
|
||||
NOTE: LDDT IS NOT SYMMETRIC. The call should always be prediction.lddt_ca(native).
|
||||
"""Compute the LDDT between this protein chain and another. NOTE: LDDT IS NOT SYMMETRIC.
|
||||
The call should always be prediction.lddt_ca(native).
|
||||
|
||||
Arguments:
|
||||
native (ProteinChain): The ground truth protein chain
|
||||
@@ -482,6 +664,171 @@ class ProteinChain:
|
||||
)
|
||||
return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()
|
||||
|
||||
def gdt_ts(
|
||||
self,
|
||||
target: ProteinChain,
|
||||
mobile_inds: list[int] | np.ndarray | None = None,
|
||||
target_inds: list[int] | np.ndarray | None = None,
|
||||
**kwargs,
|
||||
) -> float | np.ndarray:
|
||||
"""Compute the GDT_TS between this protein chain and another.
|
||||
|
||||
Arguments:
|
||||
target (ProteinChain): The other protein chain to compare to.
|
||||
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
|
||||
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
|
||||
|
||||
Returns:
|
||||
float: The GDT_TS score between the two protein chains.
|
||||
"""
|
||||
gdt_ts = compute_gdt_ts(
|
||||
mobile=torch.tensor(
|
||||
index_by_atom_name(self.atom37_positions[mobile_inds], "CA"),
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0),
|
||||
target=torch.tensor(
|
||||
index_by_atom_name(target.atom37_positions[target_inds], "CA"),
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0),
|
||||
atom_exists_mask=torch.tensor(
|
||||
index_by_atom_name(self.atom37_mask[mobile_inds], "CA", dim=-1)
|
||||
& index_by_atom_name(target.atom37_mask[target_inds], "CA", dim=-1)
|
||||
).unsqueeze(0),
|
||||
**kwargs,
|
||||
)
|
||||
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
||||
|
||||
@classmethod
|
||||
def chain_iterable_from_mmcif(
|
||||
cls,
|
||||
path: PathOrBuffer | MmcifWrapper,
|
||||
id: str | None = None,
|
||||
is_predicted: bool = False,
|
||||
keep_source: bool = False,
|
||||
):
|
||||
"""Return a list[ProteinChain] object from an mmcif file, a iterable list of all protein chain
|
||||
from an mmcif file
|
||||
"""
|
||||
if isinstance(path, MmcifWrapper):
|
||||
mmcif = path
|
||||
else:
|
||||
mmcif = MmcifWrapper.read(path, id)
|
||||
for chain in bs.chain_iter(mmcif.structure):
|
||||
chain = chain[bs.filter_amino_acids(chain) & ~chain.hetero]
|
||||
if len(chain) == 0:
|
||||
continue
|
||||
chain_id = chain.chain_id[0]
|
||||
entity_id = None
|
||||
for entity, chains in mmcif.entities.items():
|
||||
if chain_id in chains:
|
||||
entity_id = entity
|
||||
assert entity_id is not None
|
||||
(
|
||||
sequence,
|
||||
atom_positions,
|
||||
atom_mask,
|
||||
residue_index,
|
||||
insertion_code,
|
||||
confidence,
|
||||
_,
|
||||
) = chain_to_ndarray(chain, mmcif, chain_id, is_predicted)
|
||||
assert all(sequence), "Some residue name was not specified correctly"
|
||||
|
||||
yield cls(
|
||||
id=mmcif.id,
|
||||
sequence=sequence,
|
||||
chain_id=chain_id,
|
||||
entity_id=entity_id,
|
||||
atom37_positions=atom_positions,
|
||||
atom37_mask=atom_mask,
|
||||
residue_index=residue_index,
|
||||
insertion_code=insertion_code,
|
||||
confidence=confidence,
|
||||
mmcif=mmcif if keep_source else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_mmcif(
|
||||
cls,
|
||||
path: PathOrBuffer | MmcifWrapper,
|
||||
chain_id: str | None = None,
|
||||
entity_id: int | None = None,
|
||||
id: str | None = None,
|
||||
is_predicted: bool = False,
|
||||
keep_source: bool = False,
|
||||
):
|
||||
"""Return a ProteinChain object from an mmcif file.
|
||||
|
||||
Args:
|
||||
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
|
||||
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
|
||||
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
|
||||
chain_id (str, optional): Select a chain corresponding to (author) chain id.
|
||||
entity_id (int, optional): Select a chain corresponding to a particular entity.
|
||||
|
||||
If neither `chain_id` nor `entity_id` is specified, defaults to the first entity.
|
||||
"""
|
||||
if isinstance(path, MmcifWrapper):
|
||||
mmcif = path
|
||||
else:
|
||||
mmcif = MmcifWrapper.read(path, id)
|
||||
|
||||
# If neither chain_id nor entity_id is specified, default to the first entity
|
||||
if chain_id is None and entity_id is None:
|
||||
if not mmcif.entities:
|
||||
raise ValueError("Structure contains no entities")
|
||||
entity_id = min(mmcif.entities.keys()) # Pick the first entity by ID
|
||||
|
||||
if entity_id is not None:
|
||||
assert chain_id is None
|
||||
if entity_id not in mmcif.entities:
|
||||
raise ValueError(
|
||||
f"Structure does not contain entity `{entity_id}`. Valid entities: {mmcif.entities.keys()}"
|
||||
)
|
||||
chains = mmcif.entities[entity_id]
|
||||
|
||||
# Select the chain id corresponding to the longest chain. If all are equal length, selects the first.
|
||||
chain_id = max(
|
||||
chains,
|
||||
key=lambda chain: _num_non_null_residues(
|
||||
mmcif.seqres_to_structure[chain]
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert chain_id is not None
|
||||
for entity, chains in mmcif.entities.items():
|
||||
if chain_id in chains:
|
||||
entity_id = entity
|
||||
if entity_id is None:
|
||||
warnings.warn(
|
||||
"Failed to detect entity_id from mmcif file, it may be malformed."
|
||||
)
|
||||
|
||||
atom_array = mmcif.structure
|
||||
(
|
||||
sequence,
|
||||
atom_positions,
|
||||
atom_mask,
|
||||
residue_index,
|
||||
insertion_code,
|
||||
confidence,
|
||||
_,
|
||||
) = chain_to_ndarray(atom_array, mmcif, chain_id, is_predicted)
|
||||
assert all(sequence), "Some residue name was not specified correctly"
|
||||
|
||||
return cls(
|
||||
id=mmcif.id,
|
||||
sequence=sequence,
|
||||
chain_id=chain_id,
|
||||
entity_id=entity_id,
|
||||
atom37_positions=atom_positions,
|
||||
atom37_mask=atom_mask.astype(bool),
|
||||
residue_index=residue_index,
|
||||
insertion_code=insertion_code,
|
||||
confidence=confidence,
|
||||
mmcif=mmcif if keep_source else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_atom37(
|
||||
cls,
|
||||
@@ -549,11 +896,11 @@ class ProteinChain:
|
||||
|
||||
return cls(
|
||||
id=id,
|
||||
sequence=sequence,
|
||||
sequence=sequence, # type: ignore
|
||||
chain_id=chain_id,
|
||||
entity_id=entity_id,
|
||||
atom37_positions=atom37_positions,
|
||||
atom37_mask=atom_mask,
|
||||
atom37_mask=atom_mask.astype(bool),
|
||||
residue_index=residue_index,
|
||||
insertion_code=insertion_code,
|
||||
confidence=confidence,
|
||||
@@ -604,10 +951,12 @@ class ProteinChain:
|
||||
id: str | None = None,
|
||||
is_predicted: bool = False,
|
||||
) -> "ProteinChain":
|
||||
"""Return a ProteinChain object from an pdb file.
|
||||
"""Return a ProteinChain object from an pdb file. NOTE: prefer mmcif for rcsb PDB files.
|
||||
This function is mostly to interface with old PDB files and predicted structures -
|
||||
it will not fill out the entity id correctly
|
||||
|
||||
Args:
|
||||
path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
|
||||
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
|
||||
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
|
||||
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
|
||||
chain_id (str, optional): Select a chain corresponding to (author) chain id. "detect" uses the
|
||||
@@ -637,20 +986,17 @@ class ProteinChain:
|
||||
entity_id = 1 # Not supplied in PDBfiles
|
||||
|
||||
sequence = "".join(
|
||||
(
|
||||
r
|
||||
if len(r := PDBData.protein_letters_3to1.get(monomer[0].res_name, "X"))
|
||||
== 1
|
||||
else "X"
|
||||
)
|
||||
residue_constants.restype_3to1.get(monomer[0].res_name, "X")
|
||||
for monomer in bs.residue_iter(atom_array)
|
||||
)
|
||||
num_res = len(sequence)
|
||||
|
||||
atom_positions = np.full(
|
||||
[num_res, RC.atom_type_num, 3], np.nan, dtype=np.float32
|
||||
[num_res, residue_constants.atom_type_num, 3], np.nan, dtype=np.float32
|
||||
)
|
||||
atom_mask = np.full(
|
||||
[num_res, residue_constants.atom_type_num], False, dtype=bool
|
||||
)
|
||||
atom_mask = np.full([num_res, RC.atom_type_num], False, dtype=bool)
|
||||
residue_index = np.full([num_res], -1, dtype=np.int64)
|
||||
insertion_code = np.full([num_res], "", dtype="<U4")
|
||||
|
||||
@@ -671,9 +1017,11 @@ class ProteinChain:
|
||||
# Put the coords of the selenium atom in the sulphur column
|
||||
atom_name = "SD"
|
||||
|
||||
if atom_name in RC.atom_order:
|
||||
atom_positions[i, RC.atom_order[atom_name]] = atom.coord
|
||||
atom_mask[i, RC.atom_order[atom_name]] = True
|
||||
if atom_name in residue_constants.atom_order:
|
||||
atom_positions[i, residue_constants.atom_order[atom_name]] = (
|
||||
atom.coord
|
||||
)
|
||||
atom_mask[i, residue_constants.atom_order[atom_name]] = True
|
||||
if is_predicted and atom_name == "CA":
|
||||
confidence[i] = atom.b_factor
|
||||
|
||||
@@ -685,21 +1033,49 @@ class ProteinChain:
|
||||
chain_id=chain_id,
|
||||
entity_id=entity_id,
|
||||
atom37_positions=atom_positions,
|
||||
atom37_mask=atom_mask,
|
||||
atom37_mask=atom_mask.astype(bool),
|
||||
residue_index=residue_index,
|
||||
insertion_code=insertion_code,
|
||||
confidence=confidence,
|
||||
mmcif=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_rcsb(cls, pdb_id: str, chain_id: str = "detect"):
|
||||
"""Fetch a protein chain from the RCSB PDB database."""
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
|
||||
return cls.from_pdb(f, chain_id=chain_id, id=pdb_id)
|
||||
def from_mds(cls, data: dict[str, Any]) -> "ProteinChain":
|
||||
return cls(
|
||||
id=data["id"],
|
||||
chain_id=data["chain_id"],
|
||||
entity_id=data["entity_id"],
|
||||
sequence=data["sequence"],
|
||||
residue_index=data["residue_index"],
|
||||
insertion_code=np.asarray(data["insertion_code"]),
|
||||
atom37_positions=data["atom37_positions"],
|
||||
atom37_mask=data["atom37_mask"].astype(bool),
|
||||
confidence=data["confidence"],
|
||||
mmcif=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_rcsb(
|
||||
cls,
|
||||
pdb_id: str,
|
||||
chain_id: str | None = None,
|
||||
entity_id: int | None = None,
|
||||
keep_source: bool = False,
|
||||
) -> ProteinChain:
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
|
||||
return cls.from_mmcif(
|
||||
f,
|
||||
id=pdb_id,
|
||||
chain_id=chain_id,
|
||||
entity_id=entity_id,
|
||||
keep_source=keep_source,
|
||||
is_predicted=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_atomarray(
|
||||
cls, atom_array: bs.AtomArray, id: str | None = None
|
||||
cls, atom_array: bs.AtomArray, id: str | None = None, is_predicted: bool = False
|
||||
) -> "ProteinChain":
|
||||
"""A simple converter from bs.AtomArray -> ProteinChain.
|
||||
Uses PDB file format as intermediate."""
|
||||
@@ -711,7 +1087,7 @@ class ProteinChain:
|
||||
buf = io.StringIO()
|
||||
pdb_file.write(buf)
|
||||
buf.seek(0)
|
||||
return cls.from_pdb(buf, id=id)
|
||||
return cls.from_pdb(buf, id=id, is_predicted=is_predicted)
|
||||
|
||||
def get_normalization_frame(self) -> Affine3D:
|
||||
"""Given a set of coordinates, compute a single frame.
|
||||
@@ -745,6 +1121,8 @@ class ProteinChain:
|
||||
|
||||
def infer_oxygen(self) -> ProteinChain:
|
||||
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
|
||||
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
|
||||
|
||||
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
|
||||
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
|
||||
N = torch.roll(N, -3)
|
||||
@@ -756,9 +1134,11 @@ class ProteinChain:
|
||||
atom37_positions = self.atom37_positions.copy()
|
||||
atom37_mask = self.atom37_mask.copy()
|
||||
|
||||
atom37_positions[:, RC.atom_order["O"]] = O.numpy()
|
||||
atom37_mask[:, RC.atom_order["O"]] = ~np.isnan(
|
||||
atom37_positions[:, RC.atom_order["O"]]
|
||||
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
|
||||
O_missing_indices
|
||||
].numpy()
|
||||
atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
|
||||
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
|
||||
).any(-1)
|
||||
new_chain = replace(
|
||||
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
||||
@@ -781,7 +1161,7 @@ class ProteinChain:
|
||||
infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
|
||||
residues, even though that residue doesn't have one. Default off.
|
||||
|
||||
NOTE: The reason for having this switch in the first place
|
||||
NOTE(rverkuil): The reason for having this switch in the first place
|
||||
is that sometimes we want a (inferred) CB coordinate for every residue,
|
||||
for example for making a pairwise distance matrix, or doing an RMSD
|
||||
calculation between two designs for a given structural template, w/
|
||||
@@ -792,11 +1172,13 @@ class ProteinChain:
|
||||
|
||||
inferred_cbeta_positions = self.inferred_cbeta
|
||||
if not infer_cbeta_for_glycine:
|
||||
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.NAN
|
||||
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
|
||||
|
||||
atom37_positions[:, RC.atom_order["CB"]] = inferred_cbeta_positions
|
||||
atom37_mask[:, RC.atom_order["CB"]] = ~np.isnan(
|
||||
atom37_positions[:, RC.atom_order["CB"]]
|
||||
atom37_positions[:, residue_constants.atom_order["CB"]] = (
|
||||
inferred_cbeta_positions
|
||||
)
|
||||
atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
|
||||
atom37_positions[:, residue_constants.atom_order["CB"]]
|
||||
).any(-1)
|
||||
new_chain = replace(
|
||||
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
||||
@@ -822,36 +1204,69 @@ class ProteinChain:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def concat(cls, chains: Sequence[ProteinChain]):
|
||||
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
|
||||
full_array = []
|
||||
for array in arrays:
|
||||
full_array.append(array)
|
||||
full_array.append(sep)
|
||||
full_array = full_array[:-1]
|
||||
return np.concatenate(full_array, 0)
|
||||
|
||||
def concat(cls, chains: Sequence[ProteinChain], use_chainbreak: bool = True):
|
||||
sep_tokens = {
|
||||
"residue_index": np.array([-1]),
|
||||
"insertion_code": np.array([""]),
|
||||
"atom37_positions": np.full([1, 37, 3], np.nan),
|
||||
"atom37_mask": np.zeros([1, 37]),
|
||||
"atom37_positions": np.full([1, 37, 3], np.inf),
|
||||
"atom37_mask": np.zeros([1, 37], dtype=bool),
|
||||
"confidence": np.array([0]),
|
||||
}
|
||||
|
||||
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
|
||||
if use_chainbreak:
|
||||
full_array = []
|
||||
for array in arrays:
|
||||
full_array.append(array)
|
||||
full_array.append(sep)
|
||||
full_array = full_array[:-1]
|
||||
return np.concatenate(full_array, 0)
|
||||
else:
|
||||
return np.concatenate(arrays, 0)
|
||||
|
||||
array_args: dict[str, np.ndarray] = {
|
||||
name: join_arrays([getattr(chain, name) for chain in chains], sep)
|
||||
for name, sep in sep_tokens.items()
|
||||
}
|
||||
|
||||
chain_break = residue_constants.CHAIN_BREAK_TOKEN if use_chainbreak else ""
|
||||
return cls(
|
||||
id=chains[0].id,
|
||||
sequence=C.CHAIN_BREAK_STR.join(chain.sequence for chain in chains),
|
||||
sequence=chain_break.join(chain.sequence for chain in chains),
|
||||
chain_id="A",
|
||||
entity_id=None,
|
||||
mmcif=None,
|
||||
**array_args,
|
||||
)
|
||||
|
||||
def find_nonpolymer_contacts(self):
|
||||
assert self.mmcif is not None
|
||||
nonpolymer_and_chain_id_to_array = self.mmcif.non_polymer_coords
|
||||
|
||||
results = []
|
||||
for (
|
||||
nonpolymer,
|
||||
_,
|
||||
), nonpolymer_array in nonpolymer_and_chain_id_to_array.items():
|
||||
assert nonpolymer_array.coord is not None
|
||||
chain_coords = self.atom37_positions[self.atom37_mask]
|
||||
distance = cdist(nonpolymer_array.coord, chain_coords)
|
||||
|
||||
is_contact = distance < 5
|
||||
if not is_contact.any():
|
||||
continue
|
||||
contacting_atoms = np.where(is_contact.any(0))[0]
|
||||
chain_index = np.where(self.atom37_mask)[0]
|
||||
contacting_residues = np.unique(chain_index[contacting_atoms])
|
||||
|
||||
result = {
|
||||
"ligand": nonpolymer.name,
|
||||
"ligand_id": nonpolymer.comp_id,
|
||||
"contacting_residues": contacting_residues.tolist(),
|
||||
}
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def select_residue_indices(
|
||||
self, indices: list[int | str], ignore_x_mismatch: bool = False
|
||||
) -> ProteinChain:
|
||||
@@ -876,3 +1291,25 @@ class ProteinChain:
|
||||
raise RuntimeError(mismatch_str)
|
||||
|
||||
return new
|
||||
|
||||
def to_structure_encoder_inputs(
|
||||
self,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Convert protein chain to structure encoder inputs.
|
||||
|
||||
Returns:
|
||||
tuple: (coordinates, plddt, residue_index) where:
|
||||
- coordinates: (1, L, 37, 3) tensor of atom positions
|
||||
- plddt: (1, L) tensor of confidence scores
|
||||
- residue_index: (1, L) tensor of residue indices
|
||||
"""
|
||||
# Convert to tensors and add batch dimension
|
||||
coordinates = (
|
||||
torch.from_numpy(self.atom37_positions).float().unsqueeze(0)
|
||||
) # (1, L, 37, 3)
|
||||
plddt = torch.from_numpy(self.confidence).float().unsqueeze(0) # (1, L)
|
||||
residue_index = (
|
||||
torch.from_numpy(self.residue_index).long().unsqueeze(0)
|
||||
) # (1, L)
|
||||
|
||||
return coordinates, plddt, residue_index
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import io
|
||||
import itertools
|
||||
import random
|
||||
import re
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, replace
|
||||
@@ -18,24 +19,28 @@ import msgpack_numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
from biotite.database import rcsb
|
||||
from biotite.file import InvalidFileError
|
||||
from biotite.structure.io.pdb import PDBFile
|
||||
from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
|
||||
from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
|
||||
from biotite.structure.io.pdbx.convert import _get_transformations, get_structure
|
||||
from biotite.structure.util import matrix_rotate
|
||||
from scipy.spatial import KDTree
|
||||
|
||||
from esm.utils import residue_constants
|
||||
from esm.utils.constants import esm3 as esm3_c
|
||||
from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.metrics import (
|
||||
compute_gdt_ts,
|
||||
compute_lddt_ca,
|
||||
)
|
||||
from esm.utils.structure.atom_indexer import AtomIndexer
|
||||
from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
|
||||
from esm.utils.structure.protein_chain import (
|
||||
PathOrBuffer,
|
||||
ProteinChain,
|
||||
)
|
||||
from esm.utils.structure.protein_structure import (
|
||||
chain_to_ndarray,
|
||||
index_by_atom_name,
|
||||
infer_CB,
|
||||
)
|
||||
from esm.utils.types import PathOrBuffer
|
||||
|
||||
msgpack_numpy.patch()
|
||||
|
||||
@@ -44,35 +49,72 @@ SINGLE_LETTER_CHAIN_IDS = (
|
||||
)
|
||||
|
||||
|
||||
def protein_chain_to_protein_complex(chain: ProteinChain) -> ProteinComplex:
|
||||
if "|" not in chain.sequence:
|
||||
return ProteinComplex.from_chains([chain])
|
||||
chain_breaks = np.array(list(chain.sequence)) == "|"
|
||||
chain_break_inds = np.where(chain_breaks)[0]
|
||||
chain_break_inds = np.concatenate([[0], chain_break_inds, [len(chain)]])
|
||||
chain_break_inds = np.array(list(zip(chain_break_inds[:-1], chain_break_inds[1:])))
|
||||
complex_chains = []
|
||||
for start, end in chain_break_inds:
|
||||
if start != 0:
|
||||
start += 1
|
||||
complex_chains.append(chain[start:end])
|
||||
complex_chains = [
|
||||
ProteinChain.from_atom37(
|
||||
chain.atom37_positions,
|
||||
sequence=chain.sequence,
|
||||
chain_id=SINGLE_LETTER_CHAIN_IDS[i],
|
||||
entity_id=i,
|
||||
)
|
||||
for i, chain in enumerate(complex_chains)
|
||||
]
|
||||
return ProteinComplex.from_chains(complex_chains)
|
||||
def _parse_operation_expression(expression):
|
||||
"""
|
||||
Get successive operation steps (IDs) for the given
|
||||
``oper_expression``.
|
||||
Form the cartesian product, if necessary.
|
||||
Copied from biotite and fixed a bug
|
||||
"""
|
||||
# Split groups by parentheses:
|
||||
# use the opening parenthesis as delimiter
|
||||
# and just remove the closing parenthesis
|
||||
expressions_per_step = expression.replace(")", "").split("(")
|
||||
expressions_per_step = [e for e in expressions_per_step if len(e) > 0]
|
||||
# Important: Operations are applied from right to left
|
||||
expressions_per_step.reverse()
|
||||
|
||||
operations = []
|
||||
for expr in expressions_per_step:
|
||||
cur_expr = expr.split(",")
|
||||
cur_op = []
|
||||
# Deal with e='1-10,20-30,40-50' type expressions
|
||||
for e in cur_expr:
|
||||
if "-" in e:
|
||||
first, last = e.split("-")
|
||||
cur_op.extend(str(id) for id in range(int(first), int(last) + 1))
|
||||
else:
|
||||
cur_op.append(e)
|
||||
operations.append(cur_op)
|
||||
|
||||
# Cartesian product of operations
|
||||
return list(itertools.product(*operations))
|
||||
|
||||
|
||||
def _apply_transformations_fast(chains, transformation_dict, operations):
|
||||
"""
|
||||
Get subassembly by applying the given operations to the input
|
||||
structure containing affected asym IDs.
|
||||
"""
|
||||
# Additional first dimesion for 'structure.repeat()'
|
||||
results = []
|
||||
|
||||
# Apply corresponding transformation for each copy in the assembly
|
||||
for c in chains:
|
||||
for operation in operations:
|
||||
coord = c.atom37_positions.copy()
|
||||
# Execute for each transformation step
|
||||
# in the operation expression
|
||||
for op_step in operation:
|
||||
T = transformation_dict[op_step]
|
||||
# Rotate
|
||||
coord = matrix_rotate(coord, T.rotation)
|
||||
# Translate
|
||||
coord += T.target_translation
|
||||
new_chain = replace(c, atom37_positions=coord)
|
||||
results.append(new_chain)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProteinComplexMetadata:
|
||||
entity_lookup: dict[int, int]
|
||||
chain_lookup: dict[int, str]
|
||||
chain_boundaries: list[tuple[int, int]]
|
||||
mmcif: MmcifWrapper | None = None
|
||||
# This is a dictionary that maps assembly ids to the list of unique chains
|
||||
# in that assembly. Allows for usage of `switch_assembly`.
|
||||
assembly_composition: dict[str, list[str]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -100,19 +142,7 @@ class DockQResult:
|
||||
aligned_rmsd: float
|
||||
|
||||
|
||||
class AtomIndexer:
|
||||
def __init__(self, structure: ProteinComplex, property: str, dim: int):
|
||||
self.structure = structure
|
||||
self.property = property
|
||||
self.dim = dim
|
||||
|
||||
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
|
||||
return index_by_atom_name(
|
||||
getattr(self.structure, self.property), atom_names, self.dim
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class ProteinComplex:
|
||||
"""Dataclass with atom37 representation of an entire protein complex."""
|
||||
|
||||
@@ -126,6 +156,7 @@ class ProteinComplex:
|
||||
atom37_positions: np.ndarray
|
||||
atom37_mask: np.ndarray
|
||||
confidence: np.ndarray
|
||||
# This metadata is parsed from the MMCIF file. For synthetic data, we do a best effort.
|
||||
metadata: ProteinComplexMetadata
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -195,15 +226,59 @@ class ProteinComplex:
|
||||
def __len__(self):
|
||||
return len(self.sequence)
|
||||
|
||||
@property
|
||||
def num_chains(self):
|
||||
return len(self.chain_boundaries)
|
||||
|
||||
@cached_property
|
||||
def atoms(self) -> AtomIndexer:
|
||||
return AtomIndexer(self, property="atom37_positions", dim=-2)
|
||||
|
||||
@cached_property
|
||||
def atom_mask(self) -> AtomIndexer:
|
||||
return AtomIndexer(self, property="atom37_mask", dim=-1)
|
||||
|
||||
@cached_property
|
||||
def chain_lengths(self) -> np.ndarray:
|
||||
return np.diff(self.chain_boundaries, axis=1).flatten()
|
||||
|
||||
@cached_property
|
||||
def chain_boundaries(self) -> list[tuple[int, int]]:
|
||||
cb = [-1]
|
||||
for i, s in enumerate(self.sequence):
|
||||
if s == "|":
|
||||
cb.append(i)
|
||||
cb.append(len(self))
|
||||
return [(cb[i] + 1, cb[i + 1]) for i in range(len(cb) - 1)]
|
||||
|
||||
def get_chain_by_index(self, index: int) -> ProteinChain:
|
||||
try:
|
||||
start, end = self.chain_boundaries[index]
|
||||
return self[start:end].as_chain()
|
||||
except IndexError:
|
||||
raise IndexError(f"Chain index {index} out of bounds")
|
||||
|
||||
def get_chain_by_id(
|
||||
self, chain_id: str, sample_chain_if_duplicate: bool = True
|
||||
) -> ProteinChain:
|
||||
valid_indices = [
|
||||
index
|
||||
for index, id_of_index in self.metadata.chain_lookup.items()
|
||||
if id_of_index == chain_id
|
||||
]
|
||||
if not valid_indices:
|
||||
raise KeyError(f"Chain ID {chain_id} not found")
|
||||
if sample_chain_if_duplicate:
|
||||
index_to_return = random.choice(valid_indices)
|
||||
return self.get_chain_by_index(index_to_return)
|
||||
else:
|
||||
if len(valid_indices) > 1:
|
||||
raise ValueError(f"Multiple chains with chain ID {chain_id} found")
|
||||
return self.get_chain_by_index(valid_indices[0])
|
||||
|
||||
def chain_iter(self) -> Iterable[ProteinChain]:
|
||||
boundaries = [i for i, s in enumerate(self.sequence) if s == "|"]
|
||||
boundaries = [-1, *boundaries, len(self)]
|
||||
for i in range(len(boundaries) - 1):
|
||||
c = self.__getitem__(slice(boundaries[i] + 1, boundaries[i + 1]))
|
||||
for start, end in self.chain_boundaries:
|
||||
c = self[start:end]
|
||||
yield c.as_chain()
|
||||
|
||||
def as_chain(self, force_conversion: bool = False) -> ProteinChain:
|
||||
@@ -237,6 +312,7 @@ class ProteinComplex:
|
||||
residue_index=self.residue_index,
|
||||
insertion_code=self.insertion_code,
|
||||
confidence=self.confidence,
|
||||
mmcif=self.metadata.mmcif,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -253,12 +329,6 @@ class ProteinComplex:
|
||||
chains.append(ProteinChain.from_atomarray(chain, id))
|
||||
return ProteinComplex.from_chains(chains)
|
||||
|
||||
@classmethod
|
||||
def from_rcsb(cls, pdb_id: str):
|
||||
"""Fetch a protein complex from the RCSB PDB database."""
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
|
||||
return cls.from_pdb(f, id=pdb_id)
|
||||
|
||||
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
|
||||
atom_array = None
|
||||
for chain in self.chain_iter():
|
||||
@@ -284,13 +354,29 @@ class ProteinComplex:
|
||||
ids = SINGLE_LETTER_CHAIN_IDS
|
||||
chains = []
|
||||
for i, chain in enumerate(self.chain_iter()):
|
||||
chain.chain_id = ids[i]
|
||||
chain = replace(chain, chain_id=ids[i])
|
||||
if i > len(ids):
|
||||
raise RuntimeError("Too many chains to write to PDB file")
|
||||
chains.append(chain)
|
||||
|
||||
return ProteinComplex.from_chains(chains)
|
||||
|
||||
def find_assembly_ids_with_chain(self, id: str) -> list[str]:
|
||||
good_chains = []
|
||||
if (comp := self.metadata.assembly_composition) is not None:
|
||||
for assembly_id, chain_ids in comp.items():
|
||||
if id in chain_ids:
|
||||
good_chains.append(assembly_id)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot switch assemblies on this ProteinComplex, you must create the assembly from mmcif to support this"
|
||||
)
|
||||
return good_chains
|
||||
|
||||
def switch_assembly(self, id: str):
|
||||
assert self.metadata.mmcif is not None
|
||||
return get_assembly_fast(self.metadata.mmcif, assembly_id=id)
|
||||
|
||||
def state_dict(self, backbone_only=False):
|
||||
"""This state dict is optimized for storage, so it turns things to fp16 whenever
|
||||
possible. Note that we also only support int32 residue indices, I'm hoping we don't
|
||||
@@ -308,6 +394,11 @@ class ProteinComplex:
|
||||
elif isinstance(v, ProteinComplexMetadata):
|
||||
dct[k] = asdict(v)
|
||||
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
|
||||
dct["metadata"]["mmcif"] = None
|
||||
# These can be populated with non-serializable objects and are not needed for reconstruction
|
||||
dct.pop("atoms", None)
|
||||
dct.pop("atom_mask", None)
|
||||
dct.pop("per_chain_kd_trees", None)
|
||||
return dct
|
||||
|
||||
def to_blob(self, backbone_only=False) -> bytes:
|
||||
@@ -322,6 +413,10 @@ class ProteinComplex:
|
||||
k: (v.astype(np.float32) if k in ["atom37_positions", "confidence"] else v)
|
||||
for k, v in dct.items()
|
||||
}
|
||||
if "chain_boundaries" in dct:
|
||||
del dct["chain_boundaries"]
|
||||
if "chain_boundaries" in dct["metadata"]:
|
||||
del dct["metadata"]["chain_boundaries"]
|
||||
dct["metadata"] = ProteinComplexMetadata(**dct["metadata"])
|
||||
return cls(**dct)
|
||||
|
||||
@@ -342,13 +437,45 @@ class ProteinComplex:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_chains(cls, chains: Sequence[ProteinChain]):
|
||||
def from_rcsb(cls, pdb_id: str, keep_source: bool = False) -> ProteinComplex:
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
|
||||
return cls.from_mmcif(f, id=pdb_id, keep_source=keep_source, is_predicted=False)
|
||||
|
||||
@classmethod
|
||||
def from_mmcif(
|
||||
cls,
|
||||
path: PathOrBuffer,
|
||||
id: str | None = None,
|
||||
assembly_id: str | None = None,
|
||||
is_predicted: bool = False,
|
||||
keep_source: bool = False,
|
||||
):
|
||||
"""Return a ProteinComplex object from an mmcif file.
|
||||
TODO(@zeming): there's actually multiple complexes per file, but for ease of implementation,
|
||||
we only consider the first defined complex!
|
||||
|
||||
Args:
|
||||
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
|
||||
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
|
||||
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
|
||||
chain_id (str, optional): Select a chain corresponding to (author) chain id.
|
||||
"""
|
||||
mmcif = MmcifWrapper.read(path, id)
|
||||
return get_assembly_fast(mmcif, assembly_id=assembly_id)
|
||||
|
||||
@classmethod
|
||||
def from_chains(
|
||||
cls,
|
||||
chains: Sequence[ProteinChain],
|
||||
mmcif: MmcifWrapper | None = None,
|
||||
all_assembly_metadata_dictionary: dict[str, list[str]] | None = None,
|
||||
):
|
||||
if not chains:
|
||||
raise ValueError(
|
||||
"Cannot create a ProteinComplex from an empty list of chains"
|
||||
)
|
||||
|
||||
# TODO: Make a proper protein complex class
|
||||
# TODO(roshan): Make a proper protein complex class
|
||||
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
|
||||
full_array = []
|
||||
for array in arrays:
|
||||
@@ -376,7 +503,6 @@ class ProteinComplex:
|
||||
ent2num_max = -1
|
||||
ent2num = {}
|
||||
total_index = 0
|
||||
chain_boundaries = []
|
||||
for i, c in enumerate(chains):
|
||||
num_res = c.residue_index.shape[0]
|
||||
if c.chain_id not in chain2num:
|
||||
@@ -401,7 +527,6 @@ class ProteinComplex:
|
||||
}
|
||||
)
|
||||
|
||||
chain_boundaries.append((total_index, total_index + num_res))
|
||||
total_index += num_res + 1
|
||||
|
||||
sep = np.array([-1])
|
||||
@@ -412,20 +537,25 @@ class ProteinComplex:
|
||||
array_args.update(update)
|
||||
|
||||
metadata = ProteinComplexMetadata(
|
||||
chain_boundaries=chain_boundaries,
|
||||
mmcif=mmcif,
|
||||
chain_lookup={v: k for k, v in chain2num.items()},
|
||||
entity_lookup={v: k for k, v in ent2num.items()},
|
||||
assembly_composition=all_assembly_metadata_dictionary,
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=chains[0].id,
|
||||
sequence=esm3_c.CHAIN_BREAK_STR.join(chain.sequence for chain in chains),
|
||||
sequence=residue_constants.CHAIN_BREAK_TOKEN.join(
|
||||
chain.sequence for chain in chains
|
||||
),
|
||||
metadata=metadata,
|
||||
**array_args,
|
||||
)
|
||||
|
||||
def infer_oxygen(self) -> ProteinComplex:
|
||||
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
|
||||
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
|
||||
|
||||
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
|
||||
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
|
||||
N = torch.roll(N, -3)
|
||||
@@ -437,15 +567,56 @@ class ProteinComplex:
|
||||
atom37_positions = self.atom37_positions.copy()
|
||||
atom37_mask = self.atom37_mask.copy()
|
||||
|
||||
atom37_positions[:, residue_constants.atom_order["O"]] = O.numpy()
|
||||
atom37_mask[:, residue_constants.atom_order["O"]] = ~np.isnan(
|
||||
atom37_positions[:, residue_constants.atom_order["O"]]
|
||||
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
|
||||
O_missing_indices
|
||||
].numpy()
|
||||
atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
|
||||
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
|
||||
).any(-1)
|
||||
new_chain = replace(
|
||||
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
||||
)
|
||||
return new_chain
|
||||
|
||||
def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinComplex:
|
||||
"""Return a new chain with inferred CB atoms at all residues except GLY.
|
||||
|
||||
Args:
|
||||
infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
|
||||
residues, even though that residue doesn't have one. Default off.
|
||||
|
||||
NOTE(rverkuil): The reason for having this switch in the first place
|
||||
is that sometimes we want a (inferred) CB coordinate for every residue,
|
||||
for example for making a pairwise distance matrix, or doing an RMSD
|
||||
calculation between two designs for a given structural template, w/
|
||||
CB atoms.
|
||||
"""
|
||||
atom37_positions = self.atom37_positions.copy()
|
||||
atom37_mask = self.atom37_mask.copy()
|
||||
|
||||
N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0)
|
||||
# See usage in trDesign codebase.
|
||||
# https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140
|
||||
inferred_cbeta_positions = infer_CB(C, N, CA, 1.522, 1.927, -2.143)
|
||||
if not infer_cbeta_for_glycine:
|
||||
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
|
||||
|
||||
atom37_positions[:, residue_constants.atom_order["CB"]] = (
|
||||
inferred_cbeta_positions
|
||||
)
|
||||
atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
|
||||
atom37_positions[:, residue_constants.atom_order["CB"]]
|
||||
).any(-1)
|
||||
new_chain = replace(
|
||||
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
||||
)
|
||||
return new_chain
|
||||
|
||||
@classmethod
|
||||
def from_open_source(cls, pc: ProteinComplex):
|
||||
# TODO(@zeming): deprecated, should delete
|
||||
return pc
|
||||
|
||||
@classmethod
|
||||
def concat(cls, objs: list[ProteinComplex]) -> ProteinComplex:
|
||||
pdb_ids = [obj.id for obj in objs]
|
||||
@@ -463,6 +634,50 @@ class ProteinComplex:
|
||||
list(other.chain_iter())
|
||||
), "Protein complexes must have the same number of chains"
|
||||
|
||||
def rmsd(
|
||||
self,
|
||||
target: ProteinComplex,
|
||||
also_check_reflection: bool = False,
|
||||
mobile_inds: list[int] | np.ndarray | None = None,
|
||||
target_inds: list[int] | np.ndarray | None = None,
|
||||
only_compute_backbone_rmsd: bool = False,
|
||||
compute_chain_assignment: bool = True,
|
||||
):
|
||||
"""
|
||||
Compute the RMSD between this protein chain and another.
|
||||
|
||||
Args:
|
||||
target (ProteinComplex): The target (other) protein complex to compare to.
|
||||
also_check_reflection (bool, optional): If True, also check if the reflection of the mobile atoms has a lower RMSD.
|
||||
mobile_inds (list[int], optional): The indices of the mobile atoms to align. These are NOT residue indices
|
||||
target_inds (list[int], optional): The indices of the target atoms to align. These are NOT residue indices
|
||||
only_compute_backbone_rmsd (bool, optional): If True, only compute the RMSD of the backbone atoms.
|
||||
"""
|
||||
if compute_chain_assignment:
|
||||
aligned = self.dockq(target).aligned
|
||||
else:
|
||||
aligned = self
|
||||
|
||||
aligner = Aligner(
|
||||
aligned if mobile_inds is None else aligned[mobile_inds],
|
||||
target if target_inds is None else target[target_inds],
|
||||
only_compute_backbone_rmsd,
|
||||
)
|
||||
avg_rmsd = aligner.rmsd
|
||||
|
||||
if not also_check_reflection:
|
||||
return avg_rmsd
|
||||
|
||||
aligner = Aligner(
|
||||
aligned if mobile_inds is None else aligned[mobile_inds],
|
||||
target if target_inds is None else target[target_inds],
|
||||
only_compute_backbone_rmsd,
|
||||
use_reflection=True,
|
||||
)
|
||||
avg_rmsd_neg = aligner.rmsd
|
||||
|
||||
return min(avg_rmsd, avg_rmsd_neg)
|
||||
|
||||
def lddt_ca(
|
||||
self,
|
||||
target: ProteinComplex,
|
||||
@@ -537,6 +752,10 @@ class ProteinComplex:
|
||||
# This function uses dockqv2 to compute the DockQ score. Because it does a mapping
|
||||
# over all possible chains, it's quite slow. Be careful not to use this in an inference loop
|
||||
# or something that requires fast scoring. It defaults to 8 CPUs.
|
||||
#
|
||||
# TODO(@zeming): Because we haven't properly implemented protein complexes for mmcif,
|
||||
# if your protein has multi-letter or repeated chain IDs, this will fail. Please call
|
||||
# pc = pc.normalize_chain_ids_for_pdb() before calling this function in that case (limit is 62 chains)
|
||||
|
||||
try:
|
||||
pass
|
||||
@@ -658,3 +877,316 @@ class ProteinComplex:
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@cached_property
|
||||
def per_chain_kd_trees(self):
|
||||
# Iterate over chains, build KDTree for each chain
|
||||
kdtrees = []
|
||||
|
||||
CA = self.atoms["CA"]
|
||||
|
||||
for start, end in self.chain_boundaries:
|
||||
chain_CA = CA[start:end]
|
||||
chain_CA = chain_CA[np.isfinite(chain_CA).all(axis=-1)]
|
||||
kdtrees.append(KDTree(chain_CA))
|
||||
|
||||
return kdtrees
|
||||
|
||||
def chain_adjacency(self, cutoff: float = 8.0) -> np.ndarray:
|
||||
# Compute adjacency matrix for protein complex
|
||||
num_chains = self.num_chains
|
||||
adjacency = np.zeros((num_chains, num_chains), dtype=bool)
|
||||
for (i, kdtree), (j, kdtree2) in itertools.combinations(
|
||||
enumerate(self.per_chain_kd_trees), 2
|
||||
):
|
||||
adj = kdtree.query_ball_tree(kdtree2, cutoff)
|
||||
any_is_adjacent = any(len(a) > 0 for a in adj)
|
||||
adjacency[i, j] = any_is_adjacent
|
||||
adjacency[j, i] = any_is_adjacent
|
||||
return adjacency
|
||||
|
||||
def chain_adjacency_by_index(self, index: int, cutoff: float = 8.0) -> np.ndarray:
|
||||
num_chains = len(self.chain_boundaries)
|
||||
adjacency = np.zeros(num_chains, dtype=bool)
|
||||
for i, kdtree in enumerate(self.per_chain_kd_trees):
|
||||
if i == index:
|
||||
continue
|
||||
adj = kdtree.query_ball_tree(self.per_chain_kd_trees[index], cutoff)
|
||||
adjacency[i] = any(len(a) > 0 for a in adj)
|
||||
return adjacency
|
||||
|
||||
def add_prefix_to_chain_ids(self, prefix: str) -> ProteinComplex:
|
||||
"""Rename all chains in the complex with a given prefix.
|
||||
|
||||
Args:
|
||||
prefix (str): The prefix to use for the new chain IDs. Each chain will be
|
||||
named as "{prefix}_{chain_id}".
|
||||
|
||||
Returns:
|
||||
ProteinComplex: A new protein complex with renamed chains.
|
||||
"""
|
||||
new_chains = []
|
||||
for chain in self.chain_iter():
|
||||
# Create new chain with updated chain_id
|
||||
new_chain = replace(chain, chain_id=f"{prefix}_{chain.chain_id}")
|
||||
new_chains.append(new_chain)
|
||||
return ProteinComplex.from_chains(new_chains)
|
||||
|
||||
def sasa(self, by_residue: bool = True):
|
||||
chain = self.as_chain(force_conversion=True)
|
||||
return chain.sasa(by_residue=by_residue)
|
||||
|
||||
def to_mmcif_string(self) -> str:
|
||||
"""Convert the ProteinComplex to mmCIF format.
|
||||
|
||||
Returns:
|
||||
str: The mmCIF content as a string.
|
||||
"""
|
||||
# Convert the ProteinComplex to a biotite AtomArray
|
||||
# Collect all atoms from all chains
|
||||
all_atoms = []
|
||||
for chain in self.chain_iter():
|
||||
chain_atom_array = chain.atom_array
|
||||
# Convert AtomArray to list of atoms and add to collection
|
||||
all_atoms.extend(chain_atom_array)
|
||||
|
||||
# Create combined AtomArray from all atoms
|
||||
if not all_atoms:
|
||||
raise ValueError("No atoms found in protein complex")
|
||||
|
||||
atom_array = bs.array(all_atoms)
|
||||
|
||||
# Create CIF file
|
||||
f = CIFFile()
|
||||
set_structure_pdbx(f, atom_array, data_block=self.id)
|
||||
|
||||
# Add entity information for proper mmCIF structure
|
||||
self._add_entity_information(f)
|
||||
|
||||
# Write to string
|
||||
output = io.StringIO()
|
||||
f.write(output)
|
||||
return output.getvalue()
|
||||
|
||||
def _add_entity_information(self, cif_file: CIFFile) -> None:
|
||||
"""Add entity, entity_poly, and struct_asym sections to CIF file."""
|
||||
|
||||
# Group chains by sequence to create unique entities
|
||||
entity_map = {} # sequence -> entity_id
|
||||
chain_to_entity = {} # chain_id -> entity_id
|
||||
entity_sequences = {} # entity_id -> sequence
|
||||
entity_id_counter = 1
|
||||
|
||||
for chain in self.chain_iter():
|
||||
sequence = chain.sequence
|
||||
if sequence not in entity_map:
|
||||
entity_map[sequence] = entity_id_counter
|
||||
entity_sequences[entity_id_counter] = sequence
|
||||
entity_id_counter += 1
|
||||
chain_to_entity[chain.chain_id] = entity_map[sequence]
|
||||
|
||||
# Create _entity section
|
||||
entity_ids = []
|
||||
entity_types = []
|
||||
entity_descriptions = []
|
||||
|
||||
for entity_id in sorted(entity_sequences.keys()):
|
||||
entity_ids.append(str(entity_id))
|
||||
entity_types.append("polymer")
|
||||
entity_descriptions.append(f"Protein chain (entity {entity_id})")
|
||||
|
||||
cif_file.block["entity"] = CIFCategory(
|
||||
name="entity",
|
||||
columns={
|
||||
"id": CIFColumn(
|
||||
data=CIFData(array=np.array(entity_ids), dtype=np.str_)
|
||||
),
|
||||
"type": CIFColumn(
|
||||
data=CIFData(array=np.array(entity_types), dtype=np.str_)
|
||||
),
|
||||
"pdbx_description": CIFColumn(
|
||||
data=CIFData(array=np.array(entity_descriptions), dtype=np.str_)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# Create _entity_poly section
|
||||
poly_entity_ids = []
|
||||
poly_types = []
|
||||
poly_nstd_linkages = []
|
||||
poly_sequences = []
|
||||
|
||||
for entity_id in sorted(entity_sequences.keys()):
|
||||
poly_entity_ids.append(str(entity_id))
|
||||
poly_types.append("polypeptide(L)")
|
||||
poly_nstd_linkages.append("no")
|
||||
poly_sequences.append(entity_sequences[entity_id])
|
||||
|
||||
cif_file.block["entity_poly"] = CIFCategory(
|
||||
name="entity_poly",
|
||||
columns={
|
||||
"entity_id": CIFColumn(
|
||||
data=CIFData(array=np.array(poly_entity_ids), dtype=np.str_)
|
||||
),
|
||||
"type": CIFColumn(
|
||||
data=CIFData(array=np.array(poly_types), dtype=np.str_)
|
||||
),
|
||||
"nstd_linkage": CIFColumn(
|
||||
data=CIFData(array=np.array(poly_nstd_linkages), dtype=np.str_)
|
||||
),
|
||||
"pdbx_seq_one_letter_code": CIFColumn(
|
||||
data=CIFData(array=np.array(poly_sequences), dtype=np.str_)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# Create _struct_asym section
|
||||
asym_ids = []
|
||||
asym_entity_ids = []
|
||||
asym_details = []
|
||||
|
||||
for chain in self.chain_iter():
|
||||
asym_ids.append(chain.chain_id)
|
||||
asym_entity_ids.append(str(chain_to_entity[chain.chain_id]))
|
||||
asym_details.append("")
|
||||
|
||||
cif_file.block["struct_asym"] = CIFCategory(
|
||||
name="struct_asym",
|
||||
columns={
|
||||
"id": CIFColumn(data=CIFData(array=np.array(asym_ids), dtype=np.str_)),
|
||||
"entity_id": CIFColumn(
|
||||
data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_)
|
||||
),
|
||||
"details": CIFColumn(
|
||||
data=CIFData(array=np.array(asym_details), dtype=np.str_)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_assembly_fast(
|
||||
mmcif: MmcifWrapper,
|
||||
assembly_id=None,
|
||||
model=None,
|
||||
data_block=None,
|
||||
altloc="first",
|
||||
use_author_fields=True,
|
||||
):
|
||||
pdbx_file = mmcif.raw
|
||||
if pdbx_file is None:
|
||||
raise InvalidFileError("No mmCIF data loaded")
|
||||
assembly_gen_category = pdbx_file.block["pdbx_struct_assembly_gen"]
|
||||
if assembly_gen_category is None:
|
||||
raise InvalidFileError("File has no 'pdbx_struct_assembly_gen' category")
|
||||
|
||||
struct_oper_category = pdbx_file.block["pdbx_struct_oper_list"]
|
||||
if struct_oper_category is None:
|
||||
raise InvalidFileError("File has no 'pdbx_struct_oper_list' category")
|
||||
|
||||
if assembly_id is None:
|
||||
assembly_id = assembly_gen_category["assembly_id"].data.array[0]
|
||||
elif assembly_id not in assembly_gen_category["assembly_id"].data.array:
|
||||
raise KeyError(f"File has no Assembly ID '{assembly_id}'")
|
||||
|
||||
### Calculate all possible transformations
|
||||
transformations = _get_transformations(struct_oper_category)
|
||||
|
||||
### Get structure according to additional parameters
|
||||
structure = get_structure(
|
||||
pdbx_file, model, data_block, altloc, ["label_asym_id"], use_author_fields
|
||||
)[0] # type: ignore
|
||||
# TODO(@zeming) This line will remove all non-protein structural elements,
|
||||
# we should remove this when we want to parse these too.
|
||||
structure: bs.AtomArray = structure[
|
||||
bs.filter_amino_acids(structure) & ~structure.hetero # type: ignore
|
||||
]
|
||||
if len(structure) == 0:
|
||||
raise NoProteinError
|
||||
unique_asym_ids = np.unique(structure.label_asym_id) # type: ignore
|
||||
asym2chain = {}
|
||||
asym2auth = {}
|
||||
for asym_id in unique_asym_ids:
|
||||
sub_structure: bs.AtomArray = structure[structure.label_asym_id == asym_id] # type: ignore
|
||||
chain_id: str = sub_structure[0].chain_id # type: ignore
|
||||
(
|
||||
sequence,
|
||||
atom_positions,
|
||||
atom_mask,
|
||||
residue_index,
|
||||
insertion_code,
|
||||
confidence,
|
||||
entity_id,
|
||||
) = chain_to_ndarray(sub_structure, mmcif, chain_id, False)
|
||||
|
||||
asym2chain[asym_id] = ProteinChain(
|
||||
id=mmcif.id or "unknown",
|
||||
sequence=sequence,
|
||||
chain_id=chain_id,
|
||||
entity_id=entity_id,
|
||||
atom37_positions=atom_positions,
|
||||
atom37_mask=atom_mask,
|
||||
residue_index=residue_index,
|
||||
insertion_code=insertion_code,
|
||||
confidence=confidence,
|
||||
mmcif=None,
|
||||
)
|
||||
asym2auth[asym_id] = chain_id
|
||||
|
||||
### Get transformations and apply them to the affected asym IDs
|
||||
assembly = []
|
||||
assembly_id_dict: dict[str, list[str]] = {}
|
||||
|
||||
# Process the target assembly ID
|
||||
for aid, op_expr, asym_id_expr in zip(
|
||||
assembly_gen_category["assembly_id"].data.array,
|
||||
assembly_gen_category["oper_expression"].data.array,
|
||||
assembly_gen_category["asym_id_list"].data.array,
|
||||
):
|
||||
if aid == assembly_id:
|
||||
# Parse operations and asym IDs for this specific entry
|
||||
operations = _parse_operation_expression(op_expr)
|
||||
asym_ids = asym_id_expr.split(",")
|
||||
|
||||
# Filter affected asym IDs to only protein chains, preserving order
|
||||
sub_structures = [
|
||||
asym2chain[asym_id] for asym_id in asym_ids if asym_id in asym2chain
|
||||
]
|
||||
|
||||
# Apply transformations
|
||||
sub_assembly = _apply_transformations_fast(
|
||||
sub_structures, transformations, operations
|
||||
)
|
||||
assembly.extend(sub_assembly)
|
||||
|
||||
# Build assembly_id_dict for this entry
|
||||
assembly_id_dict[aid] = assembly_id_dict.get(aid, []) + [
|
||||
asym2auth[id_] for id_ in asym_ids if id_ in asym2auth
|
||||
]
|
||||
|
||||
if len(assembly) == 0:
|
||||
raise NoProteinError
|
||||
return ProteinComplex.from_chains(assembly, mmcif, assembly_id_dict)
|
||||
|
||||
|
||||
def protein_chain_to_protein_complex(chain: ProteinChain) -> ProteinComplex:
|
||||
if "|" not in chain.sequence:
|
||||
return ProteinComplex.from_chains([chain])
|
||||
chain_breaks = np.array(list(chain.sequence)) == "|"
|
||||
chain_break_inds = np.where(chain_breaks)[0]
|
||||
chain_break_inds = np.concatenate([[0], chain_break_inds, [len(chain)]])
|
||||
chain_break_inds = np.array(list(zip(chain_break_inds[:-1], chain_break_inds[1:])))
|
||||
complex_chains = []
|
||||
for start, end in chain_break_inds:
|
||||
if start != 0:
|
||||
start += 1
|
||||
complex_chains.append(chain[start:end])
|
||||
complex_chains = [
|
||||
ProteinChain.from_atom37(
|
||||
chain.atom37_positions,
|
||||
sequence=chain.sequence,
|
||||
chain_id=SINGLE_LETTER_CHAIN_IDS[i],
|
||||
entity_id=i,
|
||||
)
|
||||
for i, chain in enumerate(complex_chains)
|
||||
]
|
||||
return ProteinComplex.from_chains(complex_chains)
|
||||
|
||||
@@ -194,7 +194,7 @@ def compute_rmsd_no_alignment(
|
||||
mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
|
||||
else:
|
||||
mean_squared_error = diff.square().sum(dim=(1, 2)) / (
|
||||
num_valid_atoms.squeeze(-1) * 3
|
||||
num_valid_atoms.squeeze(-1)
|
||||
)
|
||||
|
||||
rmsd = torch.sqrt(mean_squared_error)
|
||||
|
||||
@@ -4,9 +4,7 @@ import pygtrie
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import FunctionAnnotation
|
||||
from esm.tokenization.function_tokenizer import (
|
||||
InterProQuantizedTokenizer,
|
||||
)
|
||||
from esm.tokenization.function_tokenizer import InterProQuantizedTokenizer
|
||||
|
||||
TRIE: pygtrie.CharTrie | None = None
|
||||
|
||||
|
||||
@@ -7,15 +7,11 @@ import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from esm.sdk.api import ESMProtein
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.drawing.draw_function_annotations import (
|
||||
draw_function_annotations,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
||||
draw_protein_structure,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||
from esm.widgets.utils.serialization import (
|
||||
create_download_button_from_buffer,
|
||||
protein_to_pdb_buffer,
|
||||
|
||||
@@ -3,16 +3,9 @@ from typing import Any, Callable, Sequence
|
||||
import ipywidgets as widgets
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_hex,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -4,16 +4,9 @@ import ipywidgets as widgets
|
||||
import pydssp
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_hex,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.drawing.colors import hex_to_rgba_tuple, rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@ from esm.widgets.utils.drawing.colors import (
|
||||
hex_to_rgba_tuple,
|
||||
rgba_tuple_to_rgba_html_string,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,8 @@ from matplotlib.patches import Rectangle
|
||||
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.widgets.utils import indexing
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import (
|
||||
draw_protein_structure,
|
||||
)
|
||||
from esm.widgets.utils.parsing import (
|
||||
convert_range_string_to_list_of_ranges,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_protein_structure import draw_protein_structure
|
||||
from esm.widgets.utils.parsing import convert_range_string_to_list_of_ranges
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.prompting import PromptManager
|
||||
|
||||
|
||||
@@ -9,10 +9,7 @@ from matplotlib import colormaps
|
||||
from PIL import Image
|
||||
|
||||
from esm.sdk.api import FunctionAnnotation
|
||||
from esm.utils.function.interpro import (
|
||||
InterPro,
|
||||
InterProEntryType,
|
||||
)
|
||||
from esm.utils.function.interpro import InterPro, InterProEntryType
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -9,9 +9,7 @@ from esm.sdk.api import ESMProtein, FunctionAnnotation
|
||||
from esm.utils import encoding
|
||||
from esm.widgets.utils import indexing
|
||||
from esm.widgets.utils.drawing.colors import rgba_tuple_to_hex
|
||||
from esm.widgets.utils.drawing.draw_category_array import (
|
||||
draw_data_array,
|
||||
)
|
||||
from esm.widgets.utils.drawing.draw_category_array import draw_data_array
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
|
||||
|
||||
|
||||
@@ -13,13 +13,9 @@ from esm.sdk.api import (
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.utils.constants import models
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.serialization import (
|
||||
create_download_results_button,
|
||||
)
|
||||
from esm.widgets.utils.serialization import create_download_results_button
|
||||
|
||||
|
||||
def create_esm3_generation_launcher(
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.widgets.components.sasa_prompt_selector import (
|
||||
create_sasa_prompt_selector,
|
||||
)
|
||||
from esm.widgets.components.sasa_prompt_selector import create_sasa_prompt_selector
|
||||
from esm.widgets.components.secondary_structure_prompt_selector import (
|
||||
create_secondary_structure_prompt_selector,
|
||||
)
|
||||
|
||||
@@ -4,20 +4,12 @@ from ipywidgets import widgets
|
||||
|
||||
from esm.sdk.api import ESM3InferenceClient, ESMProtein
|
||||
from esm.utils.constants import esm3 as C
|
||||
from esm.widgets.components.function_annotator import (
|
||||
create_function_annotator,
|
||||
)
|
||||
from esm.widgets.components.function_annotator import create_function_annotator
|
||||
from esm.widgets.utils.prompting import PromptManagerCollection
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
from esm.widgets.views.esm3_generation_launcher import (
|
||||
create_esm3_generation_launcher,
|
||||
)
|
||||
from esm.widgets.views.esm3_prompt_preview import (
|
||||
create_esm3_prompt_preview,
|
||||
)
|
||||
from esm.widgets.views.esm3_prompt_selector import (
|
||||
create_esm3_prompt_selector,
|
||||
)
|
||||
from esm.widgets.views.esm3_generation_launcher import create_esm3_generation_launcher
|
||||
from esm.widgets.views.esm3_prompt_preview import create_esm3_prompt_preview
|
||||
from esm.widgets.views.esm3_prompt_selector import create_esm3_prompt_selector
|
||||
|
||||
|
||||
def create_generation_ui(
|
||||
|
||||
@@ -6,9 +6,7 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
@@ -4,10 +4,7 @@ from textwrap import dedent
|
||||
|
||||
from ipywidgets import widgets
|
||||
|
||||
from esm.widgets.utils.clients import (
|
||||
get_forge_client,
|
||||
get_local_client,
|
||||
)
|
||||
from esm.widgets.utils.clients import get_forge_client, get_local_client
|
||||
from esm.widgets.utils.types import ClientInitContainer
|
||||
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
GenerationConfig,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import (
|
||||
create_results_visualizer,
|
||||
)
|
||||
from esm.widgets.components.results_visualizer import create_results_visualizer
|
||||
from esm.widgets.utils.printing import wrapped_print
|
||||
from esm.widgets.utils.protein_import import ProteinImporter
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ name = "esm"
|
||||
version = "3.2.1"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
license = {file = "LICENSE.txt"}
|
||||
|
||||
authors = [
|
||||
@@ -17,7 +17,7 @@ maintainers = [
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Topic :: Scientific/Engineering :: Bio-Informatics",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
@@ -35,9 +35,25 @@ dependencies = [
|
||||
"attrs",
|
||||
"pandas",
|
||||
"cloudpathlib",
|
||||
"httpx",
|
||||
"tenacity",
|
||||
"zstd"
|
||||
"zstd",
|
||||
"ipywidgets",
|
||||
"py3dmol",
|
||||
"pydssp",
|
||||
"boto3",
|
||||
"pygtrie",
|
||||
"dna_features_viewer",
|
||||
"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"
|
||||
]
|
||||
# Pytest
|
||||
[tool.pytest.ini_options]
|
||||
addopts = """
|
||||
--cov=esm
|
||||
--cov-report term-missing:skip-covered
|
||||
-n auto
|
||||
--ignore=tests/oss_pytests/test_oss_client.py
|
||||
"""
|
||||
|
||||
[tool.setuptools]
|
||||
package-dir = {"" = "."}
|
||||
@@ -52,7 +68,7 @@ esm = ["data/*"]
|
||||
|
||||
[tool.pixi.project]
|
||||
channels = ["conda-forge"]
|
||||
platforms = ["linux-64"]
|
||||
platforms = ["linux-64", "osx-arm64"]
|
||||
|
||||
# These are build dependencies, to ensure pip support, keep run-time deps above in `dependencies`
|
||||
[tool.pixi.dependencies]
|
||||
@@ -60,6 +76,7 @@ pkg-config = "*"
|
||||
cmake = "*"
|
||||
pip = "*"
|
||||
twine = "*"
|
||||
python = "3.12.*"
|
||||
|
||||
[tool.pixi.pypi-dependencies]
|
||||
esm = { path = ".", editable = true }
|
||||
@@ -67,3 +84,61 @@ esm = { path = ".", editable = true }
|
||||
[tool.pixi.tasks]
|
||||
build-wheel = "python -m pip wheel --no-deps -w dist ."
|
||||
upload-wheel = "python -m twine upload --repository pypi"
|
||||
|
||||
[tool.pixi.feature.dev.dependencies]
|
||||
matplotlib = "*"
|
||||
pre-commit = "*"
|
||||
pytest = "*"
|
||||
pytest-cov = "*"
|
||||
pytest-xdist = "*"
|
||||
seaborn = "*"
|
||||
pyright = "==1.1.399"
|
||||
|
||||
[tool.pixi.feature.dev.tasks]
|
||||
lint-all = "pre-commit run --all-files --show-diff-on-failure"
|
||||
cov-test = "pytest -v --junitxml=pytest.xml --cov=esm"
|
||||
|
||||
[tool.pixi.environments]
|
||||
default = {features = [], solve-group = "default"}
|
||||
dev = {features = ["dev"], solve-group = "default"}
|
||||
|
||||
[tool.ruff]
|
||||
extend-include = ["*.ipynb"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`), sort imports ('I')
|
||||
select = ["E4", "E7", "E9", "F", "I"]
|
||||
ignore = [
|
||||
# allow variable == False (tensors should do this)
|
||||
"E712",
|
||||
# allow assigning of lambdas
|
||||
"E731",
|
||||
# Allow ambiguous variables, e.g. we use O for oxygen
|
||||
"E741",
|
||||
# Ignore errors from jaxtyping hints
|
||||
# https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
|
||||
"F722",
|
||||
# TODO: Fix the few offenders in a follow up PR
|
||||
"E721",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
split-on-trailing-comma = false
|
||||
known-third-party = ["wandb"]
|
||||
|
||||
[tool.ruff.format]
|
||||
skip-magic-trailing-comma = true
|
||||
docstring-code-format = true
|
||||
docstring-code-line-length = "dynamic"
|
||||
|
||||
[tool.isort]
|
||||
known_third_party = ["wandb"]
|
||||
|
||||
[tool.pyright]
|
||||
root = ['.']
|
||||
useLibraryCodeForTypes = true
|
||||
reportPrivateImportUsage = false
|
||||
typeCheckingMode = "basic"
|
||||
|
||||
[tool.importlinter]
|
||||
root_package = "esm"
|
||||
|
||||
15
tests/Makefile
Normal file
15
tests/Makefile
Normal file
@@ -0,0 +1,15 @@
|
||||
# OSS-specific variables and commands
|
||||
DOCKER_TAG ?= dev
|
||||
DOCKER_IMAGE_OSS=oss_pytests:${DOCKER_TAG}
|
||||
|
||||
build-oss-ci:
|
||||
docker build -f oss_pytests/Dockerfile oss_pytests -t $(DOCKER_IMAGE_OSS)
|
||||
|
||||
start-docker-oss:
|
||||
docker run \
|
||||
--rm \
|
||||
-e URL=${URL} \
|
||||
-e ESM3_FORGE_TOKEN=${ESM3_FORGE_TOKEN} \
|
||||
--name=$(USER)-oss_pytests \
|
||||
--network=host \
|
||||
${DOCKER_IMAGE_OSS}
|
||||
19
tests/oss_pytests/Dockerfile
Normal file
19
tests/oss_pytests/Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
# Dockerfile.sdktest
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Install pip and basic dependencies
|
||||
RUN apt-get update && apt-get install -y curl build-essential && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /tests
|
||||
|
||||
# Copy requirements and install them (assumes esm-oss is one of them)
|
||||
COPY requirements.txt .
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy in the tests
|
||||
COPY . .
|
||||
|
||||
# Default command (can be overridden in docker run)
|
||||
CMD ["pytest", "-v", "test_oss_client.py"]
|
||||
3
tests/oss_pytests/requirements.txt
Normal file
3
tests/oss_pytests/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
esm
|
||||
pytest
|
||||
httpx # TODO(williamxi): Remove this after the esm repo is fixed
|
||||
85
tests/oss_pytests/test_oss_client.py
Normal file
85
tests/oss_pytests/test_oss_client.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from esm.sdk import client # pyright: ignore
|
||||
from esm.sdk.api import ( # pyright: ignore
|
||||
ESMProtein,
|
||||
ESMProteinTensor,
|
||||
ForwardAndSampleOutput,
|
||||
GenerationConfig,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.forge import SequenceStructureForgeInferenceClient # pyright: ignore
|
||||
|
||||
API_TOKEN = os.environ.get("ESM3_FORGE_TOKEN", "")
|
||||
URL = os.environ.get("URL")
|
||||
|
||||
|
||||
@pytest.mark.sdk
|
||||
def test_oss_esm3_client():
|
||||
assert URL is not None
|
||||
|
||||
sequence = "MALWMRLLPLLALLAL___PDPAAA"
|
||||
model = "esm3-small-2024-03"
|
||||
esm3_client = client(model=model, url=URL, token=API_TOKEN)
|
||||
protein = ESMProtein(sequence)
|
||||
|
||||
encoded_protein = esm3_client.encode(input=protein)
|
||||
assert isinstance(encoded_protein, ESMProteinTensor)
|
||||
|
||||
decoded_protein = esm3_client.decode(input=encoded_protein)
|
||||
assert isinstance(decoded_protein, ESMProtein)
|
||||
|
||||
logits_config = LogitsConfig(sequence=True, return_embeddings=True)
|
||||
result = esm3_client.logits(input=encoded_protein, config=logits_config)
|
||||
assert isinstance(result, LogitsOutput)
|
||||
|
||||
sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1))
|
||||
result = esm3_client.forward_and_sample(
|
||||
input=encoded_protein, sampling_configuration=sampling_config
|
||||
)
|
||||
assert isinstance(result, ForwardAndSampleOutput)
|
||||
|
||||
generation_config = GenerationConfig(track="sequence", num_steps=4)
|
||||
result = esm3_client.generate(input=protein, config=generation_config)
|
||||
assert isinstance(result, ESMProtein)
|
||||
|
||||
|
||||
@pytest.mark.sdk
|
||||
def test_oss_esmc_client():
|
||||
assert URL is not None
|
||||
|
||||
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
|
||||
model = "esmc-300m-2024-12"
|
||||
esmc_client = client(model=model, url=URL, token=API_TOKEN)
|
||||
|
||||
protein = ESMProtein(sequence)
|
||||
encoded_protein = esmc_client.encode(input=protein)
|
||||
assert isinstance(encoded_protein, ESMProteinTensor)
|
||||
|
||||
decoded_protein = esmc_client.decode(input=encoded_protein)
|
||||
assert isinstance(decoded_protein, ESMProtein)
|
||||
|
||||
logits_config = LogitsConfig(
|
||||
sequence=True, return_embeddings=True, return_hidden_states=True
|
||||
)
|
||||
result = esmc_client.logits(input=encoded_protein, config=logits_config)
|
||||
assert isinstance(result, LogitsOutput)
|
||||
|
||||
|
||||
@pytest.mark.sdk
|
||||
def test_oss_sequence_structure_forge_inference_client():
|
||||
assert URL is not None
|
||||
|
||||
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
|
||||
model = "esm3-small-2024-03"
|
||||
client = SequenceStructureForgeInferenceClient(
|
||||
model=model, url=URL, token=API_TOKEN
|
||||
)
|
||||
|
||||
encoded_protein = client.fold(sequence=sequence)
|
||||
assert isinstance(encoded_protein, ESMProtein)
|
||||
6
tests/oss_pytests/test_placeholder.py
Normal file
6
tests/oss_pytests/test_placeholder.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="no other tests in this suite")
|
||||
def test_placeholder():
|
||||
pass
|
||||
Reference in New Issue
Block a user