mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Add support for ProteinComplex (#125)
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "3.0.6"
|
||||
__version__ = "3.0.7"
|
||||
|
||||
@@ -14,7 +14,11 @@ from esm.tokenization import (
|
||||
)
|
||||
from esm.utils import encoding
|
||||
from esm.utils.constants.models import ESM3_OPEN_SMALL
|
||||
from esm.utils.misc import (
|
||||
get_chainbreak_boundaries_from_sequence,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.structure.protein_complex import ProteinComplex
|
||||
from esm.utils.types import (
|
||||
FunctionAnnotation,
|
||||
PathOrBuffer,
|
||||
@@ -94,9 +98,27 @@ class ESMProtein(ProteinType):
|
||||
coordinates=torch.tensor(protein_chain.atom37_positions),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_protein_complex(
|
||||
cls, protein_complex: ProteinComplex, with_annotations: bool = False
|
||||
) -> ESMProtein:
|
||||
if with_annotations:
|
||||
raise NotImplementedError(
|
||||
"Annotations are not supported for ProteinComplex yet."
|
||||
)
|
||||
|
||||
return ESMProtein(
|
||||
sequence=protein_complex.sequence,
|
||||
secondary_structure=None,
|
||||
sasa=None,
|
||||
function_annotations=None,
|
||||
coordinates=torch.tensor(protein_complex.atom37_positions),
|
||||
)
|
||||
|
||||
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
|
||||
protein_chain = self.to_protein_chain().infer_oxygen()
|
||||
protein_chain.to_pdb(pdb_path)
|
||||
# Note: Will work for single chains as well and produce same pdb file
|
||||
protein_complex = self.to_protein_complex().infer_oxygen()
|
||||
protein_complex.to_pdb(pdb_path)
|
||||
|
||||
def to_pdb_string(self) -> str:
|
||||
protein_chain = self.to_protein_chain()
|
||||
@@ -119,6 +141,33 @@ class ESMProtein(ProteinType):
|
||||
)
|
||||
return protein_chain
|
||||
|
||||
def to_protein_complex(
|
||||
self, copy_annotations_from_ground_truth: ProteinComplex | None = None
|
||||
) -> ProteinComplex:
|
||||
assert (
|
||||
self.sequence is not None
|
||||
), "ESMProtein must have a sequence to convert to ProteinComplex"
|
||||
assert (
|
||||
self.coordinates is not None
|
||||
), "ESMProtein must have coordinates to convert to ProteinComplex"
|
||||
coords = self.coordinates.to("cpu").numpy()
|
||||
|
||||
chain_boundaries = get_chainbreak_boundaries_from_sequence(self.sequence)
|
||||
if copy_annotations_from_ground_truth is not None:
|
||||
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
|
||||
else:
|
||||
gt_chains = None
|
||||
pred_chains = []
|
||||
for i, (start, end) in enumerate(chain_boundaries):
|
||||
pred_chain = ProteinChain.from_atom37(
|
||||
atom37_positions=coords[start:end],
|
||||
sequence=self.sequence[start:end],
|
||||
chain_id=gt_chains[i].chain_id if gt_chains is not None else None,
|
||||
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
|
||||
)
|
||||
pred_chains.append(pred_chain)
|
||||
return ProteinComplex.from_chains(pred_chains)
|
||||
|
||||
|
||||
@define
|
||||
class ESMProteinTensor(ProteinType):
|
||||
|
||||
@@ -414,7 +414,7 @@ def iterative_sampling_tokens(
|
||||
# that of the prompt, which may or may not be padded, depending on
|
||||
# whether the padding was done locally with the open source model
|
||||
# (where per_prompt_cur_sampled is already padded) or by
|
||||
# BatchedForwardRunner (where per_prompt_cur_sampled is not padded).
|
||||
# BatchedESM3ModelRunner (where per_prompt_cur_sampled is not padded).
|
||||
len(per_prompt_cur_sampled),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,11 +2,13 @@ import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import ContextManager, Sequence, TypeVar
|
||||
from warnings import warn
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from esm.utils.constants.esm3 import CHAIN_BREAK_STR
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
MAX_SUPPORTED_DISTANCE = 1e6
|
||||
@@ -297,3 +299,23 @@ def huggingfacehub_login():
|
||||
variable, else by prompting the user"""
|
||||
token = os.environ.get("HF_TOKEN")
|
||||
huggingface_hub.login(token=token)
|
||||
|
||||
|
||||
def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray:
|
||||
chain_boundaries = [0]
|
||||
for i, aa in enumerate(sequence):
|
||||
if aa == CHAIN_BREAK_STR:
|
||||
if i == (len(sequence) - 1):
|
||||
raise ValueError(
|
||||
"Encountered chain break token at end of sequence, this is unexpected."
|
||||
)
|
||||
if i == (len(sequence) - 2):
|
||||
warn(
|
||||
"Encountered chain break token at penultimate position, this is unexpected."
|
||||
)
|
||||
chain_boundaries.append(i)
|
||||
chain_boundaries.append(i + 1)
|
||||
chain_boundaries.append(len(sequence))
|
||||
assert len(chain_boundaries) % 2 == 0
|
||||
chain_boundaries = np.array(chain_boundaries).reshape(-1, 2)
|
||||
return chain_boundaries
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING
|
||||
from dataclasses import Field, replace
|
||||
from typing import Any, ClassVar, Protocol, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -10,15 +10,25 @@ from esm.utils.structure.protein_structure import (
|
||||
compute_affine_and_rmsd,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
|
||||
class Alignable(Protocol):
|
||||
atom37_positions: np.ndarray
|
||||
atom37_mask: np.ndarray
|
||||
# Trick to detect whether an object is a dataclass
|
||||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
||||
|
||||
def __len__(self) -> int:
|
||||
...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Alignable)
|
||||
|
||||
|
||||
class Aligner:
|
||||
def __init__(
|
||||
self,
|
||||
mobile: ProteinChain,
|
||||
target: ProteinChain,
|
||||
mobile: Alignable,
|
||||
target: Alignable,
|
||||
only_use_backbone: bool = False,
|
||||
use_reflection: bool = False,
|
||||
):
|
||||
@@ -69,7 +79,7 @@ class Aligner:
|
||||
def rmsd(self):
|
||||
return self._rmsd
|
||||
|
||||
def apply(self, mobile: ProteinChain) -> ProteinChain:
|
||||
def apply(self, mobile: T) -> T:
|
||||
"""Apply alignment to a protein chain"""
|
||||
# Extract atom positions and convert to batched tensors
|
||||
mobile_atom_tensor = (
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from esm.utils import residue_constants as RC
|
||||
|
||||
|
||||
def compute_lddt(
|
||||
all_atom_pred_pos: torch.Tensor,
|
||||
all_atom_positions: torch.Tensor,
|
||||
all_atom_mask: torch.Tensor,
|
||||
cutoff: float = 15.0,
|
||||
eps: float = 1e-10,
|
||||
per_residue: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
|
||||
Nstates:
|
||||
all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
|
||||
Natoms:
|
||||
LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.
|
||||
|
||||
Args:
|
||||
all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
|
||||
all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
|
||||
all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
|
||||
cutoff (float): Max distance to score lddt over.
|
||||
per_residue (bool): Whether to return per-residue or full-protein lddt.
|
||||
|
||||
Returns:
|
||||
LDDT Tensor:
|
||||
if per_residue:
|
||||
Tensor[float], [(Nstates x) B x (L * Natoms)]
|
||||
else:
|
||||
Tensor[float], [(Nstates x) B]
|
||||
"""
|
||||
n = all_atom_mask.shape[-2]
|
||||
dmat_true = torch.sqrt(
|
||||
eps
|
||||
+ torch.sum(
|
||||
(all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
|
||||
** 2,
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
dmat_pred = torch.sqrt(
|
||||
eps
|
||||
+ torch.sum(
|
||||
(all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
dists_to_score = (
|
||||
(dmat_true < cutoff)
|
||||
* all_atom_mask
|
||||
* rearrange(all_atom_mask, "... a b -> ... b a")
|
||||
* (1.0 - torch.eye(n, device=all_atom_mask.device))
|
||||
)
|
||||
|
||||
dist_l1 = torch.abs(dmat_true - dmat_pred)
|
||||
|
||||
score = (
|
||||
(dist_l1 < 0.5).type(dist_l1.dtype)
|
||||
+ (dist_l1 < 1.0).type(dist_l1.dtype)
|
||||
+ (dist_l1 < 2.0).type(dist_l1.dtype)
|
||||
+ (dist_l1 < 4.0).type(dist_l1.dtype)
|
||||
)
|
||||
score = score * 0.25
|
||||
|
||||
dims = (-1,) if per_residue else (-2, -1)
|
||||
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
|
||||
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def compute_lddt_ca(
|
||||
all_atom_pred_pos: torch.Tensor,
|
||||
all_atom_positions: torch.Tensor,
|
||||
all_atom_mask: torch.Tensor,
|
||||
cutoff: float = 15.0,
|
||||
eps: float = 1e-10,
|
||||
per_residue: bool = True,
|
||||
) -> torch.Tensor:
|
||||
ca_pos = RC.atom_order["CA"]
|
||||
if all_atom_pred_pos.dim() != 3:
|
||||
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
|
||||
all_atom_positions = all_atom_positions[..., ca_pos, :]
|
||||
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
|
||||
|
||||
return compute_lddt(
|
||||
all_atom_pred_pos,
|
||||
all_atom_positions,
|
||||
all_atom_mask,
|
||||
cutoff=cutoff,
|
||||
eps=eps,
|
||||
per_residue=per_residue,
|
||||
)
|
||||
@@ -26,7 +26,7 @@ from esm.utils.constants import esm3 as C
|
||||
from esm.utils.misc import slice_python_object_as_numpy
|
||||
from esm.utils.structure.affine3d import Affine3D
|
||||
from esm.utils.structure.aligner import Aligner
|
||||
from esm.utils.structure.lddt import compute_lddt_ca
|
||||
from esm.utils.structure.metrics import compute_lddt_ca
|
||||
from esm.utils.structure.normalize_coordinates import (
|
||||
apply_frame_to_coords,
|
||||
get_protein_normalization_frame,
|
||||
@@ -542,7 +542,7 @@ class ProteinChain:
|
||||
id: str | None = None,
|
||||
is_predicted: bool = False,
|
||||
) -> "ProteinChain":
|
||||
"""Return a ProteinStructure object from an pdb file.
|
||||
"""Return a ProteinChain object from an pdb file.
|
||||
|
||||
Args:
|
||||
path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
|
||||
@@ -644,6 +644,7 @@ class ProteinChain:
|
||||
pdb_id: str,
|
||||
chain_id: str = "detect",
|
||||
):
|
||||
"""Fetch a protein chain from the RCSB PDB database."""
|
||||
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
|
||||
return cls.from_pdb(f, chain_id=chain_id, id=pdb_id)
|
||||
|
||||
@@ -655,7 +656,9 @@ class ProteinChain:
|
||||
) -> "ProteinChain":
|
||||
"""A simple converter from bs.AtomArray -> ProteinChain.
|
||||
Uses PDB file format as intermediate."""
|
||||
pdb_file = bs.io.pdb.PDBFile() # pyright: ignore
|
||||
atom_array = atom_array.copy()
|
||||
atom_array.box = None # remove surrounding box, from_pdb won't handle this
|
||||
pdb_file = PDBFile() # pyright: ignore
|
||||
pdb_file.set_structure(atom_array)
|
||||
|
||||
buf = io.StringIO()
|
||||
@@ -784,7 +787,7 @@ class ProteinChain:
|
||||
sep_tokens = {
|
||||
"residue_index": np.array([-1]),
|
||||
"insertion_code": np.array([""]),
|
||||
"atom37_positions": np.full([1, 37, 3], np.inf),
|
||||
"atom37_positions": np.full([1, 37, 3], np.nan),
|
||||
"atom37_mask": np.zeros([1, 37]),
|
||||
"confidence": np.array([0]),
|
||||
}
|
||||
|
||||
@@ -261,3 +261,49 @@ def compute_affine_and_rmsd(
|
||||
)
|
||||
|
||||
return affine, avg_rmsd
|
||||
|
||||
|
||||
def compute_gdt_ts_no_alignment(
|
||||
aligned: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
atom_exists_mask: torch.Tensor,
|
||||
reduction: str = "batch",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute GDT_TS between two batches of structures without alignment.
|
||||
|
||||
Args:
|
||||
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
||||
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
||||
- atom_exists_mask (torch.Tensor): Mask for Whether an atom exists of shape (B, N). noo
|
||||
- reduction (str): One of "batch", "per_sample".
|
||||
|
||||
Returns:
|
||||
If reduction == "batch":
|
||||
(torch.Tensor): 0-dim, GDT_TS between the structures for each batch
|
||||
If reduction == "per_sample":
|
||||
(torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
|
||||
"""
|
||||
if reduction not in ("per_sample", "batch"):
|
||||
raise ValueError("Unrecognized reduction: '{reduction}'")
|
||||
|
||||
if atom_exists_mask is None:
|
||||
atom_exists_mask = torch.isfinite(target).all(dim=-1)
|
||||
|
||||
deviation = torch.linalg.vector_norm(aligned - target, dim=-1)
|
||||
num_valid_atoms = atom_exists_mask.sum(dim=-1)
|
||||
|
||||
# Compute GDT_TS
|
||||
score = (
|
||||
((deviation < 1) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
||||
+ ((deviation < 2) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
||||
+ ((deviation < 4) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
||||
+ ((deviation < 8) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
||||
) * 0.25
|
||||
|
||||
if reduction == "batch":
|
||||
return score.mean()
|
||||
elif reduction == "per_sample":
|
||||
return score
|
||||
else:
|
||||
raise ValueError("Unrecognized reduction: '{reduction}'")
|
||||
|
||||
@@ -11,6 +11,7 @@ from esm.sdk.api import (
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.utils.structure.protein_chain import ProteinChain
|
||||
from esm.utils.structure.protein_complex import ProteinComplex
|
||||
from esm.utils.types import FunctionAnnotation
|
||||
|
||||
|
||||
@@ -25,6 +26,12 @@ def get_sample_protein() -> ESMProtein:
|
||||
return protein
|
||||
|
||||
|
||||
def get_sample_protein_complex() -> ESMProtein:
|
||||
protein = ProteinComplex.from_rcsb("7a3w")
|
||||
protein = ESMProtein.from_protein_complex(protein)
|
||||
return protein
|
||||
|
||||
|
||||
def main(client: ESM3InferenceClient):
|
||||
# Single step decoding
|
||||
protein = get_sample_protein()
|
||||
@@ -124,6 +131,21 @@ def main(client: ESM3InferenceClient):
|
||||
), f"ESMProtein was expected but got {cot_protein}"
|
||||
cot_protein.to_pdb("./sample_cot.pdb")
|
||||
|
||||
# Protein Complex
|
||||
protein = get_sample_protein_complex()
|
||||
sequence_length = len(protein.sequence) # type: ignore
|
||||
num_steps = 1
|
||||
folded_protein = client.generate(
|
||||
protein,
|
||||
GenerationConfig(
|
||||
track="structure", schedule="cosine", num_steps=num_steps, temperature=0.0
|
||||
),
|
||||
)
|
||||
assert isinstance(
|
||||
folded_protein, ESMProtein
|
||||
), f"ESMProtein was expected but got {protein}"
|
||||
folded_protein.to_pdb("./sample_folded_complex.pdb")
|
||||
|
||||
# Batch examples.
|
||||
|
||||
# Batch generation.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.0.6"
|
||||
version = "3.0.7"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
Reference in New Issue
Block a user