Add support for ProteinComplex (#125)

This commit is contained in:
Jenna MacCarley
2024-10-21 17:00:29 -04:00
committed by GitHub
parent a46ffdc2fc
commit 3fd11564af
10 changed files with 168 additions and 114 deletions

View File

@@ -1 +1 @@
__version__ = "3.0.6"
__version__ = "3.0.7"

View File

@@ -14,7 +14,11 @@ from esm.tokenization import (
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.misc import (
get_chainbreak_boundaries_from_sequence,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import ProteinComplex
from esm.utils.types import (
FunctionAnnotation,
PathOrBuffer,
@@ -94,9 +98,27 @@ class ESMProtein(ProteinType):
coordinates=torch.tensor(protein_chain.atom37_positions),
)
@classmethod
def from_protein_complex(
    cls, protein_complex: ProteinComplex, with_annotations: bool = False
) -> ESMProtein:
    """Build an ESMProtein from a ProteinComplex.

    Only the sequence and the atom37 coordinates are carried over; the
    secondary-structure, SASA and function-annotation tracks are left as None.

    Args:
        protein_complex: Source complex; its ``sequence`` and
            ``atom37_positions`` attributes are read.
        with_annotations: Must be False. Annotation transfer is not
            implemented for complexes.

    Returns:
        A new ESMProtein holding the complex's sequence and coordinates.

    Raises:
        NotImplementedError: If ``with_annotations`` is True.
    """
    # Guard first: nothing is converted when annotations are requested.
    if with_annotations:
        raise NotImplementedError(
            "Annotations are not supported for ProteinComplex yet."
        )

    complex_coordinates = torch.tensor(protein_complex.atom37_positions)
    return ESMProtein(
        coordinates=complex_coordinates,
        sequence=protein_complex.sequence,
        function_annotations=None,
        secondary_structure=None,
        sasa=None,
    )
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
protein_chain = self.to_protein_chain().infer_oxygen()
protein_chain.to_pdb(pdb_path)
# Note: Will work for single chains as well and produce same pdb file
protein_complex = self.to_protein_complex().infer_oxygen()
protein_complex.to_pdb(pdb_path)
def to_pdb_string(self) -> str:
protein_chain = self.to_protein_chain()
@@ -119,6 +141,33 @@ class ESMProtein(ProteinType):
)
return protein_chain
def to_protein_complex(
    self, copy_annotations_from_ground_truth: ProteinComplex | None = None
) -> ProteinComplex:
    """Convert this protein into a ProteinComplex, one chain per sequence span.

    The sequence is split at chain-break tokens via
    ``get_chainbreak_boundaries_from_sequence``; each resulting span of the
    sequence and of the coordinates becomes one ProteinChain.

    Args:
        copy_annotations_from_ground_truth: Optional reference complex. When
            given, the i-th reference chain supplies ``chain_id`` and
            ``entity_id`` for the i-th generated chain; otherwise both are
            passed as None.

    Returns:
        A ProteinComplex assembled from the per-span chains.

    Raises:
        AssertionError: If ``sequence`` or ``coordinates`` is None.
    """
    assert (
        self.sequence is not None
    ), "ESMProtein must have a sequence to convert to ProteinComplex"
    assert (
        self.coordinates is not None
    ), "ESMProtein must have coordinates to convert to ProteinComplex"

    # Coordinates are moved to host memory so they can be sliced as numpy.
    atom37 = self.coordinates.to("cpu").numpy()
    spans = get_chainbreak_boundaries_from_sequence(self.sequence)

    reference_chains = (
        None
        if copy_annotations_from_ground_truth is None
        else list(copy_annotations_from_ground_truth.chain_iter())
    )

    chains = []
    for index, (begin, stop) in enumerate(spans):
        if reference_chains is None:
            chain_id = entity_id = None
        else:
            # Reference chains are matched to spans purely by position.
            reference = reference_chains[index]
            chain_id, entity_id = reference.chain_id, reference.entity_id
        chains.append(
            ProteinChain.from_atom37(
                atom37_positions=atom37[begin:stop],
                sequence=self.sequence[begin:stop],
                chain_id=chain_id,
                entity_id=entity_id,
            )
        )
    return ProteinComplex.from_chains(chains)
@define
class ESMProteinTensor(ProteinType):

View File

@@ -414,7 +414,7 @@ def iterative_sampling_tokens(
# that of the prompt, which may or may not be padded, depending on
# whether the padding was done locally with the open source model
# (where per_prompt_cur_sampled is already padded) or by
# BatchedForwardRunner (where per_prompt_cur_sampled is not padded).
# BatchedESM3ModelRunner (where per_prompt_cur_sampled is not padded).
len(per_prompt_cur_sampled),
)

View File

@@ -2,11 +2,13 @@ import math
import os
from collections import defaultdict
from typing import ContextManager, Sequence, TypeVar
from warnings import warn
import huggingface_hub
import numpy as np
import torch
from esm.utils.constants.esm3 import CHAIN_BREAK_STR
from esm.utils.types import FunctionAnnotation
MAX_SUPPORTED_DISTANCE = 1e6
@@ -297,3 +299,23 @@ def huggingfacehub_login():
variable, else by prompting the user"""
token = os.environ.get("HF_TOKEN")
huggingface_hub.login(token=token)
def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray:
    """Locate the [start, end) span of every chain in a chain-break-joined sequence.

    Each element equal to ``CHAIN_BREAK_STR`` closes the current span at its own
    index and opens the next span one position later, so the break token itself
    belongs to no span.

    Args:
        sequence: Sequence whose elements are compared against
            ``CHAIN_BREAK_STR``.

    Returns:
        Integer array of shape (number_of_spans, 2); each row is a
        (start, end) pair with an exclusive end. A sequence without breaks
        yields the single row ``[0, len(sequence)]``.

    Raises:
        ValueError: If the final element is a chain-break token.
    """
    final_index = len(sequence) - 1
    span_starts = [0]
    span_ends = []

    for position, residue in enumerate(sequence):
        if residue != CHAIN_BREAK_STR:
            continue
        if position == final_index:
            raise ValueError(
                "Encountered chain break token at end of sequence, this is unexpected."
            )
        if position == final_index - 1:
            # Leaves a trailing chain of length one; allowed but suspicious.
            warn(
                "Encountered chain break token at penultimate position, this is unexpected."
            )
        span_ends.append(position)
        span_starts.append(position + 1)

    span_ends.append(len(sequence))
    # Every opened span must have been closed.
    assert len(span_starts) == len(span_ends)
    return np.array(list(zip(span_starts, span_ends))).reshape(-1, 2)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import replace
from typing import TYPE_CHECKING
from dataclasses import Field, replace
from typing import Any, ClassVar, Protocol, TypeVar
import numpy as np
import torch
@@ -10,15 +10,25 @@ from esm.utils.structure.protein_structure import (
compute_affine_and_rmsd,
)
if TYPE_CHECKING:
from esm.utils.structure.protein_chain import ProteinChain
class Alignable(Protocol):
    """Structural type for objects that expose atom37 data and are dataclasses.

    Any object providing the members below satisfies this protocol; no
    inheritance is required.
    """

    # Atom coordinates in atom37 layout (exact shape not visible here --
    # presumably (L, 37, 3); confirm against the implementing classes).
    atom37_positions: np.ndarray
    # Companion mask for atom37_positions (presumably (L, 37); confirm).
    atom37_mask: np.ndarray

    # Trick to detect whether an object is a dataclass: dataclasses carry
    # this class-level attribute.
    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]

    def __len__(self) -> int: ...
T = TypeVar("T", bound=Alignable)
class Aligner:
def __init__(
self,
mobile: ProteinChain,
target: ProteinChain,
mobile: Alignable,
target: Alignable,
only_use_backbone: bool = False,
use_reflection: bool = False,
):
@@ -69,7 +79,7 @@ class Aligner:
def rmsd(self):
return self._rmsd
def apply(self, mobile: ProteinChain) -> ProteinChain:
def apply(self, mobile: T) -> T:
"""Apply alignment to a protein chain"""
# Extract atom positions and convert to batched tensors
mobile_atom_tensor = (

View File

@@ -1,98 +0,0 @@
import torch
from einops import rearrange
from esm.utils import residue_constants as RC
def compute_lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """
    Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
        Nstates:
            all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
        Natoms:
            LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.

    Args:
        all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
        all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
        all_atom_mask (Tensor[float], [B x (L * Natoms) x 1]): Tensor of masks, indicating whether an atom exists.
            The trailing singleton dimension is required: the atom count is read from ``shape[-2]`` and the
            mask is multiplied by its own transpose to form the pairwise mask.
        cutoff (float): Max distance to score lddt over.
        eps (float): Small constant added under the square roots and to the numerator/denominator of the
            final ratio to avoid zero gradients and division by zero.
        per_residue (bool): Whether to return per-residue or full-protein lddt.

    Returns:
        LDDT Tensor:
            if per_residue:
                Tensor[float], [(Nstates x) B x (L * Natoms)]
            else:
                Tensor[float], [(Nstates x) B]
    """
    n = all_atom_mask.shape[-2]

    # Pairwise distance matrices of the true and the predicted structure.
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
            ** 2,
            dim=-1,
        )
    )
    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )

    # A pair is scored when it is within the cutoff in the true structure,
    # both atoms exist, and it is not an atom paired with itself.
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * all_atom_mask.transpose(-1, -2)
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    # Fraction of the four tolerance thresholds each pair satisfies.
    dist_l1 = torch.abs(dmat_true - dmat_pred)
    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25

    # Reduce over partners only (per residue) or over all pairs (per structure).
    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))

    return score
def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """Compute LDDT restricted to the CA atom of every residue.

    When the predicted positions are not 3-dimensional, the CA entry is
    selected from the atom axis of the predictions, the true positions and
    the mask before delegating to ``compute_lddt``. 3-dimensional predictions
    are forwarded unchanged.

    Args:
        all_atom_pred_pos: Predicted positions.
        all_atom_positions: True positions.
        all_atom_mask: Atom-existence mask.
        cutoff: Forwarded to ``compute_lddt``.
        eps: Forwarded to ``compute_lddt``.
        per_residue: Forwarded to ``compute_lddt``.

    Returns:
        The tensor returned by ``compute_lddt`` for the CA-only inputs.
    """
    ca_index = RC.atom_order["CA"]
    if all_atom_pred_pos.dim() != 3:
        # A length-one slice (not an integer index) keeps the mask's last axis.
        ca_keepdim = slice(ca_index, ca_index + 1)
        all_atom_pred_pos = all_atom_pred_pos[..., ca_index, :]
        all_atom_positions = all_atom_positions[..., ca_index, :]
        all_atom_mask = all_atom_mask[..., ca_keepdim]

    lddt_options = {"cutoff": cutoff, "eps": eps, "per_residue": per_residue}
    return compute_lddt(
        all_atom_pred_pos, all_atom_positions, all_atom_mask, **lddt_options
    )

View File

@@ -26,7 +26,7 @@ from esm.utils.constants import esm3 as C
from esm.utils.misc import slice_python_object_as_numpy
from esm.utils.structure.affine3d import Affine3D
from esm.utils.structure.aligner import Aligner
from esm.utils.structure.lddt import compute_lddt_ca
from esm.utils.structure.metrics import compute_lddt_ca
from esm.utils.structure.normalize_coordinates import (
apply_frame_to_coords,
get_protein_normalization_frame,
@@ -542,7 +542,7 @@ class ProteinChain:
id: str | None = None,
is_predicted: bool = False,
) -> "ProteinChain":
"""Return a ProteinStructure object from an pdb file.
"""Return a ProteinChain object from a PDB file.
Args:
path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
@@ -644,6 +644,7 @@ class ProteinChain:
pdb_id: str,
chain_id: str = "detect",
):
"""Fetch a protein chain from the RCSB PDB database."""
f: io.StringIO = rcsb.fetch(pdb_id, "pdb") # type: ignore
return cls.from_pdb(f, chain_id=chain_id, id=pdb_id)
@@ -655,7 +656,9 @@ class ProteinChain:
) -> "ProteinChain":
"""A simple converter from bs.AtomArray -> ProteinChain.
Uses PDB file format as intermediate."""
pdb_file = bs.io.pdb.PDBFile() # pyright: ignore
atom_array = atom_array.copy()
atom_array.box = None # remove surrounding box, from_pdb won't handle this
pdb_file = PDBFile() # pyright: ignore
pdb_file.set_structure(atom_array)
buf = io.StringIO()
@@ -784,7 +787,7 @@ class ProteinChain:
sep_tokens = {
"residue_index": np.array([-1]),
"insertion_code": np.array([""]),
"atom37_positions": np.full([1, 37, 3], np.inf),
"atom37_positions": np.full([1, 37, 3], np.nan),
"atom37_mask": np.zeros([1, 37]),
"confidence": np.array([0]),
}

View File

@@ -261,3 +261,49 @@ def compute_affine_and_rmsd(
)
return affine, avg_rmsd
def compute_gdt_ts_no_alignment(
aligned: torch.Tensor,
target: torch.Tensor,
atom_exists_mask: torch.Tensor,
reduction: str = "batch",
) -> torch.Tensor:
"""
Compute GDT_TS between two batches of structures without alignment.
Args:
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
- atom_exists_mask (torch.Tensor): Mask for Whether an atom exists of shape (B, N). noo
- reduction (str): One of "batch", "per_sample".
Returns:
If reduction == "batch":
(torch.Tensor): 0-dim, GDT_TS between the structures for each batch
If reduction == "per_sample":
(torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
"""
if reduction not in ("per_sample", "batch"):
raise ValueError("Unrecognized reduction: '{reduction}'")
if atom_exists_mask is None:
atom_exists_mask = torch.isfinite(target).all(dim=-1)
deviation = torch.linalg.vector_norm(aligned - target, dim=-1)
num_valid_atoms = atom_exists_mask.sum(dim=-1)
# Compute GDT_TS
score = (
((deviation < 1) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
+ ((deviation < 2) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
+ ((deviation < 4) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
+ ((deviation < 8) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
) * 0.25
if reduction == "batch":
return score.mean()
elif reduction == "per_sample":
return score
else:
raise ValueError("Unrecognized reduction: '{reduction}'")

View File

@@ -11,6 +11,7 @@ from esm.sdk.api import (
SamplingTrackConfig,
)
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_complex import ProteinComplex
from esm.utils.types import FunctionAnnotation
@@ -25,6 +26,12 @@ def get_sample_protein() -> ESMProtein:
return protein
def get_sample_protein_complex() -> ESMProtein:
    """Load entry "7a3w" via ``ProteinComplex.from_rcsb`` and wrap it as an ESMProtein."""
    return ESMProtein.from_protein_complex(ProteinComplex.from_rcsb("7a3w"))
def main(client: ESM3InferenceClient):
# Single step decoding
protein = get_sample_protein()
@@ -124,6 +131,21 @@ def main(client: ESM3InferenceClient):
), f"ESMProtein was expected but got {cot_protein}"
cot_protein.to_pdb("./sample_cot.pdb")
# Protein Complex
protein = get_sample_protein_complex()
sequence_length = len(protein.sequence) # type: ignore
num_steps = 1
folded_protein = client.generate(
protein,
GenerationConfig(
track="structure", schedule="cosine", num_steps=num_steps, temperature=0.0
),
)
assert isinstance(
folded_protein, ESMProtein
), f"ESMProtein was expected but got {protein}"
folded_protein.to_pdb("./sample_folded_complex.pdb")
# Batch examples.
# Batch generation.

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.0.6"
version = "3.0.7"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"