3.2.2.post2 (#276)

This commit is contained in:
Neil Thomas
2025-09-19 15:28:14 -07:00
committed by GitHub
parent c8969198c6
commit 3e109e2d1b
24 changed files with 2199 additions and 55 deletions

View File

@@ -49,18 +49,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{

View File

@@ -80,18 +80,18 @@
"\n",
"The largest ESM3 (98 billion parameters) was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens. To create esmGFP we used the 7 billion parameter variant of ESM3. We'll use this model via the [EvolutionaryScale Forge](https://forge.evolutionaryscale.ai) API.\n",
"\n",
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"id": "zNrU9Q2SYonX"
},
"outputs": [],
"source": [
"token = getpass(\"Token from Forge console: \")"
"token = getpass(\"Token from Forge: \")"
]
},
{

View File

@@ -53,7 +53,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Grab a token from [the Forge console](https://forge.evolutionaryscale.ai/console) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
"Grab a token from [Forge](https://forge.evolutionaryscale.ai/) and add it below. Note that your token is like a password for your account and you should take care to protect it. For this reason it is recommended to frequently create a new token and delete old, unused ones. It is also recommended to paste the token directly into an environment variable or use a utility like `getpass` as shown below so tokens are not accidentally shared or checked into code repositories."
]
},
{
@@ -64,7 +64,7 @@
"source": [
"from getpass import getpass\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(model=\"esm3-open\", url=\"https://forge.evolutionaryscale.ai\", token=token)"
]
},

View File

@@ -120,7 +120,7 @@
"\n",
"from esm.sdk import client\n",
"\n",
"token = getpass(\"Token from Forge console: \")\n",
"token = getpass(\"Token from Forge: \")\n",
"model = client(\n",
" model=\"esm3-medium-2024-08\", url=\"https://forge.evolutionaryscale.ai\", token=token\n",
")"

View File

@@ -1 +1 @@
__version__ = "3.2.2"
__version__ = "3.2.2.post2"

View File

@@ -148,12 +148,35 @@ class ESMProtein(ProteinType):
gt_chains = list(copy_annotations_from_ground_truth.chain_iter())
else:
gt_chains = None
# Expand pLDDT to match sequence length if needed, inserting NaN at chain breaks
# This handles the case where the server doesn't include chain breaks in pLDDT
# We should fix this in the server side.
if self.plddt is not None and len(self.plddt) != len(self.sequence):
# Only expand if there's a mismatch (likely due to chain breaks)
if "|" in self.sequence:
# Create expanded pLDDT with NaN at chain break positions
expanded_plddt = torch.full((len(self.sequence),), float("nan"))
plddt_idx = 0
for i, aa in enumerate(self.sequence):
if aa != "|":
if plddt_idx < len(self.plddt):
expanded_plddt[i] = self.plddt[plddt_idx]
plddt_idx += 1
plddt = expanded_plddt
else:
# Mismatch but no chain breaks - shouldn't happen but preserve original
plddt = self.plddt
else:
plddt = self.plddt
pred_chains = []
for i, (start, end) in enumerate(chain_boundaries):
if i >= len(SINGLE_LETTER_CHAIN_IDS):
raise ValueError(
f"Too many chains to convert to ProteinComplex. The maximum number of chains is {len(SINGLE_LETTER_CHAIN_IDS)}"
)
pred_chain = ProteinChain.from_atom37(
atom37_positions=coords[start:end],
sequence=self.sequence[start:end],
@@ -161,7 +184,7 @@ class ESMProtein(ProteinType):
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
confidence=self.plddt[start:end] if self.plddt is not None else None,
confidence=plddt[start:end] if plddt is not None else None,
)
pred_chains.append(pred_chain)
return ProteinComplex.from_chains(pred_chains)
@@ -298,13 +321,6 @@ class GenerationConfig:
self.temperature_annealing = True
@define
class MSA:
# Paired MSA sequences.
# One would typically compute these using, for example, ColabFold.
sequences: list[str]
@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import base64
import pickle
@@ -7,7 +9,6 @@ from typing import Any, Sequence
import torch
from esm.sdk.api import (
MSA,
ESM3InferenceClient,
ESMCInferenceClient,
ESMProtein,
@@ -27,6 +28,15 @@ from esm.sdk.base_forge_client import _BaseForgeInferenceClient
from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
from esm.utils.msa import MSA
from esm.utils.structure.input_builder import (
StructurePredictionInput,
serialize_structure_prediction_input,
)
from esm.utils.structure.molecular_complex import (
MolecularComplex,
MolecularComplexResult,
)
from esm.utils.types import FunctionAnnotation
@@ -36,10 +46,8 @@ def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None:
return [FunctionAnnotation(*t) for t in l]
def _maybe_logits(data: dict[str, Any], track: str, return_bytes: bool = False):
ret = data.get("logits", {}).get(track, None)
# TODO(s22chan): just return this when removing return_bytes
return ret if ret is None or not return_bytes else maybe_tensor(ret)
def _maybe_logits(data: dict[str, Any], track: str):
return maybe_tensor(data.get("logits", {}).get(track, None))
def _maybe_b64_decode(obj, return_bytes: bool):
@@ -137,7 +145,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
data = await self._async_post(
"msa", request={}, params={"sequence": sequence, "use_env": False}
)
return MSA(sequences=data["msa"])
return MSA.from_sequences(sequences=data["msa"])
def _fetch_msa(self, sequence: str) -> MSA:
print("Fetching MSA ... this may take a few minutes")
@@ -146,7 +154,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
data = self._post(
"msa", request={}, params={"sequence": sequence, "use_env": False}
)
return MSA(sequences=data["msa"])
return MSA.from_sequences(sequences=data["msa"])
@retry_decorator
async def async_fold(
@@ -209,6 +217,70 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
return self._process_fold_response(data, sequence)
@retry_decorator
async def async_fold_all_atom(
    self, all_atom_input: StructurePredictionInput, model_name: str | None = None
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
    """Asynchronously fold a molecular complex of proteins, nucleic acids, and/or ligands.

    Args:
        all_atom_input: StructurePredictionInput containing sequences for different molecule types
        model_name: Override the client level model name if needed

    Returns:
        The processed ``fold_all_atom`` response, or the ``ESMProteinError``
        raised while posting the request (it is returned, not re-raised).
    """
    chosen_model = self.model if model_name is None else model_name
    payload = self._process_fold_all_atom_request(all_atom_input, chosen_model)
    try:
        response = await self._async_post("fold_all_atom", payload)
    except ESMProteinError as err:
        return err
    return self._process_fold_all_atom_response(response)
@retry_decorator
def fold_all_atom(
    self, all_atom_input: StructurePredictionInput, model_name: str | None = None
) -> MolecularComplexResult | list[MolecularComplexResult] | ESMProteinError:
    """Predict coordinates for a molecular complex of proteins, DNA, RNA, and/or ligands.

    Args:
        all_atom_input: StructurePredictionInput containing sequences for different molecule types
        model_name: Override the client level model name if needed

    Returns:
        The processed ``fold_all_atom`` response, or the ``ESMProteinError``
        raised while posting the request (it is returned, not re-raised).
    """
    chosen_model = self.model if model_name is None else model_name
    payload = self._process_fold_all_atom_request(all_atom_input, chosen_model)
    try:
        response = self._post("fold_all_atom", payload)
    except ESMProteinError as err:
        return err
    return self._process_fold_all_atom_response(response)
@staticmethod
def _process_fold_all_atom_request(
    all_atom_input: StructurePredictionInput, model_name: str | None = None
) -> dict[str, Any]:
    """Build the request payload for the ``fold_all_atom`` endpoint.

    Args:
        all_atom_input: Input to serialize under the ``"all_atom_input"`` key.
        model_name: Value sent under the ``"model"`` key (may be None).

    Returns:
        A dict with the serialized input and the model name.
    """
    return {
        "all_atom_input": serialize_structure_prediction_input(all_atom_input),
        "model": model_name,
    }
@staticmethod
def _process_fold_all_atom_response(data: dict[str, Any]) -> MolecularComplexResult:
    """Convert a ``fold_all_atom`` response dict into a MolecularComplexResult.

    Reads the ``"complex"``, ``"plddt"``, ``"ptm"`` and ``"distogram"`` keys;
    missing keys are read as None. ``plddt`` and ``distogram`` go through
    ``maybe_tensor`` with ``convert_none_to_nan=True``.
    """
    structure = MolecularComplex.from_state_dict(data.get("complex"))
    plddt = maybe_tensor(data.get("plddt"), convert_none_to_nan=True)
    distogram = maybe_tensor(data.get("distogram"), convert_none_to_nan=True)
    return MolecularComplexResult(
        complex=structure, plddt=plddt, ptm=data.get("ptm"), distogram=distogram
    )
@retry_decorator
async def async_inverse_fold(
self,
@@ -602,19 +674,15 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
return LogitsOutput(
logits=ForwardTrackData(
sequence=_maybe_logits(data, "sequence", return_bytes),
structure=_maybe_logits(data, "structure", return_bytes),
secondary_structure=_maybe_logits(
data, "secondary_structure", return_bytes
),
sasa=_maybe_logits(data, "sasa", return_bytes),
function=_maybe_logits(data, "function", return_bytes),
sequence=_maybe_logits(data, "sequence"),
structure=_maybe_logits(data, "structure"),
secondary_structure=_maybe_logits(data, "secondary_structure"),
sasa=_maybe_logits(data, "sasa"),
function=_maybe_logits(data, "function"),
),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
residue_annotation_logits=_maybe_logits(
data, "residue_annotation", return_bytes
),
residue_annotation_logits=_maybe_logits(data, "residue_annotation"),
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
@@ -965,6 +1033,7 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
"sequence": config.sequence,
"return_embeddings": config.return_embeddings,
"return_mean_embedding": config.return_mean_embedding,
"return_mean_hidden_states": config.return_mean_hidden_states,
"return_hidden_states": config.return_hidden_states,
"ith_hidden_layer": config.ith_hidden_layer,
}
@@ -981,12 +1050,11 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
data["hidden_states"] = _maybe_b64_decode(data["hidden_states"], return_bytes)
output = LogitsOutput(
logits=ForwardTrackData(
sequence=_maybe_logits(data, "sequence", return_bytes)
),
logits=ForwardTrackData(sequence=_maybe_logits(data, "sequence")),
embeddings=maybe_tensor(data["embeddings"]),
mean_embedding=data["mean_embedding"],
hidden_states=maybe_tensor(data["hidden_states"]),
mean_hidden_state=maybe_tensor(data["mean_hidden_state"]),
)
return output

View File

@@ -2,10 +2,9 @@ import inspect
from contextvars import ContextVar
from functools import wraps
import httpx
from tenacity import (
retry,
retry_if_exception_type,
retry_if_exception,
retry_if_result,
stop_after_attempt,
wait_incrementing,
@@ -30,8 +29,12 @@ def retry_if_specific_error(exception):
def log_retry_attempt(retry_state):
try:
outcome = retry_state.outcome.result()
except Exception:
outcome = retry_state.outcome.exception()
print(
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {retry_state.outcome.result()}"
f"Retrying... Attempt {retry_state.attempt_number} after {retry_state.next_action.sleep}s due to: {outcome}"
)
@@ -41,13 +44,18 @@ def retry_decorator(func):
instance's retry settings.
"""
def return_last_value(retry_state):
"""Return the result of the last call attempt."""
return retry_state.outcome.result()
@wraps(func)
async def async_wrapper(instance, *args, **kwargs):
if skip_retries_var.get():
return await func(instance, *args, **kwargs)
retry_decorator = retry(
retry_error_callback=return_last_value,
retry=retry_if_result(retry_if_specific_error)
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
| retry_if_exception(retry_if_specific_error),
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),
@@ -62,8 +70,9 @@ def retry_decorator(func):
if skip_retries_var.get():
return func(instance, *args, **kwargs)
retry_decorator = retry(
retry_error_callback=return_last_value,
retry=retry_if_result(retry_if_specific_error)
| retry_if_exception_type(httpx.ConnectTimeout), # ADDED
| retry_if_exception(retry_if_specific_error),
wait=wait_incrementing(
increment=1, start=instance.min_retry_wait, max=instance.max_retry_wait
),

View File

@@ -43,7 +43,9 @@ def _trim_sequence_tensor_dataclass(o: Any, sequence_len: int):
sliced = {}
for k, v in attr.asdict(o, recurse=False).items():
if v is None:
if k in ["mean_hidden_state", "mean_embedding"]:
sliced[k] = v
elif v is None:
sliced[k] = None
elif isinstance(v, torch.Tensor):
# Trim padding.

View File

@@ -1,8 +1,20 @@
from __future__ import annotations
import os
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import is_dataclass
from io import BytesIO
from typing import Any, ContextManager, Sequence, TypeVar
from typing import (
Any,
ContextManager,
Generator,
Iterable,
Protocol,
Sequence,
TypeVar,
runtime_checkable,
)
from warnings import warn
import huggingface_hub
@@ -19,6 +31,12 @@ MAX_SUPPORTED_DISTANCE = 1e6
TSequence = TypeVar("TSequence", bound=Sequence)
@runtime_checkable
class Concatable(Protocol):
    """Structural type for objects whose class provides a ``concat`` classmethod.

    Marked ``runtime_checkable`` so it can be used with ``isinstance`` and in
    ``match`` class patterns.
    """

    @classmethod
    def concat(cls, objs: list[Concatable]) -> Concatable: ...
def slice_python_object_as_numpy(
obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
@@ -53,6 +71,37 @@ def slice_python_object_as_numpy(
return sliced_obj # type: ignore
def slice_any_object(
    obj: TSequence, idx: int | list[int] | slice | np.ndarray
) -> TSequence:
    """
    Slice an arbitrary object (like a list, string, or tuple) as if it were a numpy object.

    Similar to `slice_python_object_as_numpy`, but numpy arrays and torch Tensors
    are indexed directly with their own indexing. Dataclasses are also indexed
    directly, under the assumption that they correctly implement numpy-style indexing.

    Example:
    >>> obj = "ABCDE"
    >>> slice_any_object(obj, [1, 3, 4])
    "BDE"
    >>> obj = np.array([1, 2, 3, 4, 5])
    >>> slice_any_object(obj, np.arange(5) < 3)
    np.array([1, 2, 3])
    >>> obj = ProteinChain.from_rcsb("1a3a", "A")
    >>> slice_any_object(obj, np.arange(len(obj)) < 10)
    # ProteinChain w/ length 10
    """
    indexes_natively = isinstance(obj, (np.ndarray, torch.Tensor)) or is_dataclass(
        obj
    )
    if not indexes_natively:
        return slice_python_object_as_numpy(obj, idx)
    # Arrays, tensors and dataclasses are trusted to handle the index themselves.
    return obj[idx]  # type: ignore
def rbf(values, v_min, v_max, n_bins=16):
"""
Returns RBF encodings in a new dimension at the end.
@@ -213,7 +262,7 @@ def unbinpack(
return stack_variable_length_tensors(unpacked_tensors, pad_value)
def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast]: # type: ignore
def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore
"""
Returns an autocast context manager that disables downcasting by AMP.
@@ -302,6 +351,8 @@ def replace_inf(data):
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
if x is None:
return None
if isinstance(x, torch.Tensor):
return x
if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x):
return torch.stack(x)
if convert_none_to_nan:
@@ -361,3 +412,90 @@ def deserialize_tensors(b: bytes) -> Any:
buf = BytesIO(zstd.ZSTD_uncompress(b))
d = torch.load(buf, map_location="cpu", weights_only=False)
return d
def join_lists(
lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None
) -> list[Any]:
"""Joins multiple lists with separator element. Like str.join but for lists.
Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4]
Args:
lists: Lists of elements to chain
separator: separators to intsert between chained output.
Returns:
Joined lists.
"""
if not lists:
return []
joined = []
joined.extend(lists[0])
for l in lists[1:]:
if separator:
joined.extend(separator)
joined.extend(l)
return joined
def iterate_with_intermediate(
    lists: Iterable, intermediate
) -> Generator[Any, None, None]:
    """
    Iterate over the iterable, yielding the intermediate value between
    every pair of consecutive elements of the iterable. Useful for joining
    objects with separator tokens.

    Args:
        lists: Iterable whose elements are yielded in order.
        intermediate: Value yielded between consecutive elements.

    Yields:
        The elements of ``lists`` interleaved with ``intermediate``. An empty
        iterable yields nothing.
    """
    it = iter(lists)
    try:
        first = next(it)
    except StopIteration:
        # A StopIteration escaping a generator body becomes RuntimeError
        # (PEP 479), so an empty input must end the generator explicitly.
        return
    yield first
    for item in it:
        yield intermediate
        yield item
def concat_objects(objs: Sequence[Any], separator: Any | None = None):
    """
    Concatenate objects, optionally inserting a separator between them.

    The type of the first object decides how the inputs are joined. Supports:
    - Concatable (objects that implement `concat` classmethod; the separator
      is not passed to `concat`)
    - strings (the separator must be a string)
    - lists
    - numpy arrays
    - torch Tensors

    Raises:
        TypeError: if the first object is none of the supported types.

    Example:
    >>> foo = "abc"
    >>> bar = "def"
    >>> concat_objects([foo, bar], "|")
    "abc|def"
    """
    head = objs[0]
    if isinstance(head, Concatable):
        return head.__class__.concat(objs)  # type: ignore
    if isinstance(head, str):
        assert isinstance(
            separator, str
        ), "Trying to join strings but separator is not a string"
        return separator.join(objs)
    if isinstance(head, list):
        if separator is None:
            return join_lists(objs)
        return join_lists(objs, [separator])
    if isinstance(head, np.ndarray):
        if separator is None:
            return np.concatenate(objs)
        # Wrap the separator in a one-element array so it can be concatenated.
        pieces = list(iterate_with_intermediate(objs, np.array([separator])))
        return np.concatenate(pieces)
    if isinstance(head, torch.Tensor):
        if separator is None:
            return torch.cat(objs)  # type: ignore
        pieces = list(iterate_with_intermediate(objs, torch.tensor([separator])))
        return torch.cat(pieces)
    raise TypeError(type(head))

View File

@@ -0,0 +1,3 @@
from esm.utils.msa.msa import MSA, FastMSA, remove_insertions_from_sequence
__all__ = ["MSA", "FastMSA", "remove_insertions_from_sequence"]

View File

@@ -0,0 +1,79 @@
import tempfile
from pathlib import Path
import numpy as np
from scipy.spatial.distance import cdist
from esm.utils.system import run_subprocess_with_errorcheck
def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
    """Greedily pick rows that maximize (or minimize) average Hamming distance.

    Algorithm proposed in the MSA Transformer paper. Starting from row 0 (the
    query), repeatedly add the unselected row whose mean Hamming distance to the
    rows selected so far is largest (``mode="max"``) or smallest (``mode="min"``).

    Args:
        array (np.ndarray): Character array representing the sequences in the MSA,
            one row per sequence. It is reinterpreted as uint8 via ``view``.
        num_seqs (int): Number of sequences to select.
        mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min'
            unless you're doing it to prove a point for a paper.

    Returns:
        list[int]: Sorted row indices to select from the array. All rows are
        returned when the array has at most ``num_seqs`` rows.
    """
    assert mode in ("max", "min")
    depth = array.shape[0]
    if num_seqs >= depth:
        return list(range(depth))

    encoded = array.view(np.uint8)
    pick = np.argmax if mode == "max" else np.argmin
    candidates = np.arange(depth)
    chosen = [0]
    # One row of distances (to every sequence) per selected sequence.
    distance_rows = np.zeros((0, depth))
    for _ in range(num_seqs - 1):
        newest = encoded[chosen[-1:]]
        distance_rows = np.concatenate(
            [distance_rows, cdist(newest, encoded, "hamming")]
        )
        # Score only the not-yet-selected columns, then map the winning
        # position back to an index into the full array.
        mean_to_chosen = np.delete(distance_rows, chosen, axis=1).mean(0)
        best = pick(mean_to_chosen)
        chosen.append(np.delete(candidates, chosen)[best])
    return sorted(chosen)
def hhfilter(
    sequences: list[str],
    seqid: int = 90,
    diff: int = 0,
    cov: int = 0,
    qid: int = 0,
    qsc: float = -20.0,
    binary: str = "hhfilter",
) -> list[int]:
    """Run the external ``hhfilter`` binary and return the indices it keeps.

    The sequences are written to a temporary FASTA file whose headers are the
    sequence positions (``>0``, ``>1``, ...). The headers of the filtered output
    are parsed back into integers.

    Args:
        sequences: Sequences to filter.
        seqid: Value forwarded to the binary's ``-id`` flag.
        diff: Value forwarded to the binary's ``-diff`` flag.
        cov: Value forwarded to the binary's ``-cov`` flag.
        qid: Value forwarded to the binary's ``-qid`` flag.
        qsc: Value forwarded to the binary's ``-qsc`` flag.
        binary: Command used to invoke hhfilter. Whitespace in this string
            still splits it into several argv entries, as before.

    Returns:
        Indices into ``sequences`` read from the output file's headers.
    """
    # Prefer RAM-backed /dev/shm for scratch files, but fall back to the
    # platform default temp directory where it does not exist.
    shm = Path("/dev/shm")
    scratch_root = shm if shm.is_dir() else None
    with tempfile.TemporaryDirectory(dir=scratch_root) as tempdirname:
        tempdir = Path(tempdirname)
        fasta_file = tempdir / "input.fasta"
        fasta_file.write_text(
            "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
        )
        output_file = tempdir / "output.fasta"
        # Build argv directly so file paths containing spaces stay intact.
        command = [
            *binary.split(),
            "-i",
            str(fasta_file),
            "-M",
            "a3m",
            "-o",
            str(output_file),
            "-id",
            str(seqid),
            "-diff",
            str(diff),
            "-cov",
            str(cov),
            "-qid",
            str(qid),
            "-qsc",
            str(qsc),
        ]
        run_subprocess_with_errorcheck(command, capture_output=True)
        with output_file.open() as f:
            indices = [int(line[1:].strip()) for line in f if line.startswith(">")]
    return indices

500
esm/utils/msa/msa.py Normal file
View File

@@ -0,0 +1,500 @@
from __future__ import annotations
import dataclasses
import string
from dataclasses import dataclass
from functools import cached_property
from itertools import islice
from typing import Sequence
import numpy as np
from Bio import SeqIO
from scipy.spatial.distance import cdist
from esm.utils.misc import slice_any_object
from esm.utils.msa.filter_sequences import greedy_select_indices, hhfilter
from esm.utils.parsing import FastaEntry, read_sequences, write_sequences
from esm.utils.sequential_dataclass import SequentialDataclass
from esm.utils.system import PathOrBuffer
# Translation table that deletes every ASCII lowercase letter.
REMOVE_LOWERCASE_TRANSLATION = str.maketrans("", "", string.ascii_lowercase)


def remove_insertions_from_sequence(seq: str) -> str:
    """Return ``seq`` with every ASCII lowercase letter removed.

    All other characters (uppercase letters, ``-``, ``.``, ...) are kept.
    """
    return seq.translate(REMOVE_LOWERCASE_TRANSLATION)
@dataclass(frozen=True)
class MSA(SequentialDataclass):
"""Object-oriented interface to an MSA.
Args:
sequences (list[str]): List of protein sequences
headers (list[str]): List of headers describing the sequences
"""
entries: list[FastaEntry]
@cached_property
def sequences(self) -> list[str]:
    """Sequence strings of all entries, in entry order (computed once, then cached)."""
    return [entry.sequence for entry in self.entries]
@cached_property
def headers(self) -> list[str]:
    """Headers of all entries, in entry order (computed once, then cached)."""
    return [entry.header for entry in self.entries]
def __repr__(self):
    """Summarize the MSA by its first header, depth, and sequence length."""
    query_header = self.entries[0].header
    return f"MSA({query_header}: Depth={self.depth}, Length={self.seqlen})"
def to_fast_msa(self) -> FastMSA:
    """Build a FastMSA from this MSA's character array and headers."""
    return FastMSA(self.array, self.headers)
@classmethod
def from_a3m(
    cls,
    path: PathOrBuffer,
    remove_insertions: bool = True,
    max_sequences: int | None = None,
) -> MSA:
    """Read an MSA from an a3m file or buffer.

    Args:
        path: Source passed to ``read_sequences``.
        remove_insertions: If True, strip lowercase letters from each sequence.
        max_sequences: Read at most this many records; None reads all.

    Returns:
        An MSA with one entry per record read. Every sequence must have the
        same length as the first one (checked with an assertion).
    """
    entries = []
    for header, seq in islice(read_sequences(path), max_sequences):
        aligned = remove_insertions_from_sequence(seq) if remove_insertions else seq
        if entries:
            expected = len(entries[0].sequence)
            assert (
                len(aligned) == expected
            ), f"Sequence length mismatch. Expected: {expected}, Received: {len(aligned)}"
        entries.append(FastaEntry(header, aligned))
    return cls(entries)
def to_a3m(self, path: PathOrBuffer) -> None:
    """Write all entries to ``path`` using ``write_sequences``."""
    write_sequences(self.entries, path)
@classmethod
def from_stockholm(
    cls,
    path: PathOrBuffer,
    remove_insertions: bool = True,
    max_sequences: int | None = None,
) -> MSA:
    """Read an MSA from a Stockholm file or buffer.

    Args:
        path: Source passed to ``SeqIO.parse(path, "stockholm")``.
        remove_insertions: If True, drop every column where the first (query)
            sequence has a ``-``.
        max_sequences: Read at most this many records; None reads all.

    Returns:
        An MSA whose headers are ``"<record.id> <record.description>"``. Every
        sequence must have the same length as the first one (checked with an
        assertion).
    """
    entries = []
    for record in islice(SeqIO.parse(path, "stockholm"), max_sequences):
        header = f"{record.id} {record.description}"
        seq = str(record.seq)
        if entries:
            expected = len(entries[0].sequence)
            assert (
                len(seq) == expected
            ), f"Sequence length mismatch. Expected: {expected}, Received: {len(seq)}"
        entries.append(FastaEntry(header, seq))

    msa = cls(entries)
    if not remove_insertions:
        return msa
    query_columns = [pos for pos, residue in enumerate(msa.query) if residue != "-"]
    return msa.select_positions(query_columns)
def to_bytes(self) -> bytes:
    """Serialize sequences and headers into a single bytes object.

    Layout: 1 byte version (currently 1), 4 bytes sequence length, 4 bytes
    depth (both little-endian), the raw bytes of the character array, then
    the headers joined by newlines.
    """
    format_version = 1
    return b"".join(
        [
            format_version.to_bytes(1, "little"),
            self.seqlen.to_bytes(4, "little"),
            self.depth.to_bytes(4, "little"),
            self.array.tobytes(),
            "\n".join(entry.header for entry in self.entries).encode(),
        ]
    )
@classmethod
def from_bytes(cls, data: bytes) -> MSA:
    """Rebuild an MSA from the format written by ``to_bytes``.

    Layout: 1 byte version, 4 bytes sequence length, 4 bytes depth (both
    little-endian), ``depth * seqlen`` single-byte characters, then the
    headers joined by newlines.

    Args:
        data: Serialized MSA.

    Returns:
        The decoded MSA.

    Raises:
        ValueError: If the version is not 1, or if the headers cannot be
            matched one-to-one with the sequences.
    """
    version = int.from_bytes(data[:1], "little")
    if version != 1:
        raise ValueError(f"Unsupported version: {version}")
    seqlen = int.from_bytes(data[1:5], "little")
    depth = int.from_bytes(data[5:9], "little")
    body = data[9:]
    n_chars = seqlen * depth
    array = np.frombuffer(body[:n_chars], dtype="|S1").reshape(depth, seqlen)

    headers = body[n_chars:].decode().split("\n")
    if len(headers) != depth:
        # Sometimes the separation is two newlines, which results in an empty
        # header. Only drop blanks when the raw count is off, so that entries
        # with genuinely empty headers (e.g. from `from_sequences`) survive a
        # round trip instead of being silently dropped.
        headers = [header for header in headers if header]
    if len(headers) != depth:
        raise ValueError(
            f"Header count mismatch. Expected: {depth}, Received: {len(headers)}"
        )

    entries = [
        FastaEntry(header, b"".join(row).decode())
        for header, row in zip(headers, array)
    ]
    return cls(entries)
# TODO(jmaccarl): set remove_insertions to True by default here to match other utils
@classmethod
def from_sequences(
    cls, sequences: list[str], remove_insertions: bool = False
) -> MSA:
    """Build an MSA from bare sequences; every entry gets an empty header.

    Args:
        sequences: Sequences, one per row.
        remove_insertions: If True, strip lowercase letters from each sequence.
    """
    if remove_insertions:
        sequences = [remove_insertions_from_sequence(seq) for seq in sequences]
    return cls([FastaEntry("", seq) for seq in sequences])
def to_sequence_bytes(self) -> bytes:
    """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost.

    Layout: 4 bytes little-endian sequence length followed by the raw bytes
    of the character array.
    """
    return self.seqlen.to_bytes(4, "little") + self.array.tobytes()
@classmethod
def from_sequence_bytes(cls, data: bytes) -> MSA:
    """Rebuild an MSA from the format written by ``to_sequence_bytes``.

    The first 4 bytes hold the little-endian sequence length; the rest is
    split into rows of that length. All headers are empty.
    """
    seqlen = int.from_bytes(data[:4], "little")
    rows = np.frombuffer(data[4:], dtype="|S1").reshape(-1, seqlen)
    return cls([FastaEntry("", b"".join(row).decode()) for row in rows])
@property
def depth(self) -> int:
    """Number of entries (rows) in the MSA."""
    return len(self.entries)
@property
def seqlen(self) -> int:
    """Length of the first entry's sequence; raises IndexError on an empty MSA."""
    return len(self.entries[0].sequence)
@cached_property
def array(self) -> np.ndarray:
    """Sequences as a single-byte character array (dtype ``|S1``), cached.

    One row per sequence and one column per position when all sequences have
    the same length.
    """
    return np.array([list(seq) for seq in self.sequences], dtype="|S1")
@property
def query(self) -> str:
    """Sequence of the first entry."""
    return self.entries[0].sequence
def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA:
    """Subselect rows of the MSA.

    Args:
        indices: Positions of the entries to keep, in output order.

    Returns:
        A copy of this MSA (via ``dataclasses.replace``) with only those entries.
    """
    entries = [self.entries[idx] for idx in indices]
    return dataclasses.replace(self, entries=entries)
def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA:
    """Subselect columns of the MSA.

    Args:
        indices: Column positions to keep, in output order.

    Returns:
        A copy of this MSA where each sequence holds only those columns;
        headers are unchanged.
    """
    trimmed = []
    for header, seq in self.entries:
        columns = "".join(seq[idx] for idx in indices)
        trimmed.append(FastaEntry(header, columns))
    return dataclasses.replace(self, entries=trimmed)
def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
    """Index the MSA along the sequence (column) dimension.

    A single int is treated as a one-element list, so the result is always
    an MSA. Each sequence is sliced with ``slice_any_object``.
    """
    columns = [indices] if isinstance(indices, int) else indices
    sliced = [
        FastaEntry(header, slice_any_object(seq, columns))
        for header, seq in self.entries
    ]
    return dataclasses.replace(self, entries=sliced)
def __len__(self):
    """Return the sequence length (number of columns), not the depth."""
    return self.seqlen
def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA:
    """Greedily select sequences that maximize or minimize Hamming distance.

    Algorithm proposed in the MSA Transformer paper; the selection itself is
    done by ``greedy_select_indices`` on this MSA's character array.

    Args:
        num_seqs (int): Number of sequences to select.
        mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min'
            unless you're doing it to prove a point for a paper.

    Returns:
        MSA object w/ subselected sequences, or this MSA itself when it has at
        most ``num_seqs`` sequences.
    """
    assert mode in ("max", "min")
    if num_seqs >= self.depth:
        return self
    return self.select_sequences(greedy_select_indices(self.array, num_seqs, mode))
def hhfilter(
    self,
    seqid: int = 90,
    diff: int = 0,
    cov: int = 0,
    qid: int = 0,
    qsc: float = -20.0,
    binary: str = "hhfilter",
) -> MSA:
    """Apply hhfilter to the sequences in the MSA and return a filtered MSA.

    All arguments are forwarded unchanged to the module-level ``hhfilter``
    function, which returns the indices of the sequences to keep.
    """
    indices = hhfilter(
        self.sequences,
        seqid=seqid,
        diff=diff,
        cov=cov,
        qid=qid,
        qsc=qsc,
        binary=binary,
    )
    return self.select_sequences(indices)
def select_random_sequences(self, num_seqs: int) -> MSA:
    """Randomly subselect ``num_seqs`` sequences, always keeping the query.

    Returns this MSA unchanged when it already has at most ``num_seqs``
    sequences. Sampling is without replacement and uses the global
    ``np.random`` state; the original row order is preserved.
    """
    if num_seqs >= self.depth:
        return self
    # Sample from rows 1..depth-1, then put the query (row 0) back in front.
    others = np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
    indices = np.sort(np.append(0, others))
    return self.select_sequences(indices)  # type: ignore
def select_diverse_sequences(self, num_seqs: int) -> MSA:
    """Select roughly ``num_seqs`` sequences with hhfilter, then sample down.

    Returns this MSA unchanged when it already has at most ``num_seqs``
    sequences. Otherwise hhfilter is run with ``diff=num_seqs``; if it keeps
    more than ``num_seqs`` sequences, they are randomly subsampled.
    """
    if num_seqs >= self.depth:
        return self
    filtered = self.hhfilter(diff=num_seqs)
    if filtered.depth > num_seqs:
        return filtered.select_random_sequences(num_seqs)
    return filtered
def pad_to_depth(self, depth: int) -> MSA:
    """Pad the MSA with all-gap rows until it has ``depth`` rows.

    Added rows have an empty header and a sequence of ``-`` characters of
    length ``seqlen``.

    Raises:
        ValueError: If ``depth`` is smaller than the current depth.
    """
    missing = depth - self.depth
    if missing < 0:
        raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
    if missing == 0:
        return self
    gap_row = "-" * self.seqlen
    padding = [FastaEntry("", gap_row) for _ in range(missing)]
    return dataclasses.replace(self, entries=self.entries + padding)
@classmethod
def stack(
    cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True
) -> MSA:
    """Stack MSAs vertically, along the depth dimension.

    Args:
        msas: MSAs to stack, in order.
        remove_query_from_later_msas: If True, the first entry of every MSA
            after the first one is dropped.
    """
    stacked = []
    for position, msa in enumerate(msas):
        drop_query = position > 0 and remove_query_from_later_msas
        stacked.extend(msa.entries[1:] if drop_query else msa.entries)
    return cls(entries=stacked)
@cached_property
def seqid(self) -> np.ndarray:
    """Sequence identity of every row to the first row, cached.

    Computed as one minus the Hamming distance between the first row and
    each row of the character array.
    """
    encoded = self.array.view(np.uint8)
    query_row = encoded[0][None]
    return (1 - cdist(query_row, encoded, "hamming"))[0]
@classmethod
def concat(
    cls,
    msas: Sequence[MSA],
    join_token: str | None = "|",
    allow_depth_mismatch: bool = False,
) -> MSA:
    """Concatenate a series of MSAs horizontally, along the sequence dimension.

    Args:
        msas: MSAs to concatenate, in order.
        join_token: String placed between the sequences of adjacent MSAs;
            None joins them with nothing in between.
        allow_depth_mismatch: If True, shallower MSAs are padded to the
            greatest depth with ``pad_to_depth`` instead of raising.

    Returns:
        An MSA whose headers are the per-row headers joined with ``|`` and
        whose sequences are the per-row sequences joined with ``join_token``.

    Raises:
        ValueError: If ``msas`` is empty, or depths differ and
            ``allow_depth_mismatch`` is False.
    """
    if not msas:
        raise ValueError("Cannot concatenate an empty list of MSAs")

    depths = {msa.depth for msa in msas}
    if len(depths) != 1:
        if not allow_depth_mismatch:
            raise ValueError("Depth mismatch in concatenating MSAs")
        target = max(depths)
        msas = [msa.pad_to_depth(target) for msa in msas]

    separator = "" if join_token is None else join_token
    header_rows = zip(*(msa.headers for msa in msas))
    sequence_rows = zip(*(msa.sequences for msa in msas))
    entries = [
        FastaEntry("|".join(str(h) for h in row_headers), separator.join(row_seqs))
        for row_headers, row_seqs in zip(header_rows, sequence_rows)
    ]
    return cls(entries)
@dataclass(frozen=True)
class FastMSA(SequentialDataclass):
"""Object-oriented interface to an MSA stored as a numpy uint8 array."""
array: np.ndarray
headers: list[str] | None = None
def __post_init__(self):
    """Check that, when headers are given, there is exactly one per row."""
    if self.headers is not None:
        assert (
            len(self.headers) == self.depth
        ), "Number of headers must match depth."
@classmethod
def from_bytes(cls, data: bytes) -> FastMSA:
version_bytes, seqlen_bytes, depth_bytes, data = (
data[:1],
data[1:5],
data[5:9],
data[9:],
)
version = int.from_bytes(version_bytes, "little")
if version != 1:
raise ValueError(f"Unsupported version: {version}")
seqlen = int.from_bytes(seqlen_bytes, "little")
depth = int.from_bytes(depth_bytes, "little")
array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
array = np.frombuffer(array_bytes, dtype="|S1")
array = array.reshape(depth, seqlen)
headers = header_bytes.decode().split("\n")
# Sometimes the separation is two newlines, which results in an empty header.
headers = [header for header in headers if header]
return cls(array, headers)
@classmethod
def from_sequence_bytes(cls, data: bytes) -> FastMSA:
seqlen_bytes, array_bytes = data[:4], data[4:]
seqlen = int.from_bytes(seqlen_bytes, "little")
array = np.frombuffer(array_bytes, dtype="|S1")
array = array.reshape(-1, seqlen)
return cls(array)
@property
def depth(self) -> int:
return self.array.shape[0]
@property
def seqlen(self) -> int:
return self.array.shape[1]
def __len__(self):
return self.seqlen
def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
if isinstance(indices, int):
indices = [indices]
return dataclasses.replace(self, array=self.array[:, indices])
def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA:
"""Subselect rows of the MSA."""
array = self.array[indices]
headers = (
[self.headers[idx] for idx in indices] if self.headers is not None else None
)
return dataclasses.replace(self, array=array, headers=headers)
def select_random_sequences(self, num_seqs: int) -> FastMSA:
"""Uses random sampling to subselect sequences from the MSA. Always
keeps the query sequence.
"""
if num_seqs >= self.depth:
return self
# Subselect random, always keeping the query sequence.
indices = np.sort(
np.append(
0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
)
)
msa = self.select_sequences(indices) # type: ignore
return msa
def pad_to_depth(self, depth: int) -> FastMSA:
if depth < self.depth:
raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
elif depth == self.depth:
return self
num_to_add = depth - self.depth
array = np.pad(
self.array,
[(0, num_to_add), (0, 0)],
constant_values=ord("-") if self.array.dtype == np.uint8 else b"-",
)
headers = self.headers
if headers is not None:
headers = headers + [""] * num_to_add
return dataclasses.replace(self, array=array, headers=headers)
@classmethod
def concat(
cls,
msas: Sequence[FastMSA],
join_token: str | None = None,
allow_depth_mismatch: bool = False,
) -> FastMSA:
"""Concatenate a series of MSAs horizontally, along the sequence dimension."""
if not msas:
raise ValueError("Cannot concatenate an empty list of MSAs")
if join_token is not None and join_token != "":
raise NotImplementedError("join_token is not supported for FastMSA")
msa_depths = [msa.depth for msa in msas]
if len(set(msa_depths)) != 1:
if not allow_depth_mismatch:
raise ValueError("Depth mismatch in concatenating MSAs")
else:
max_depth = max(msa_depths)
msas = [msa.pad_to_depth(max_depth) for msa in msas]
headers = [
"|".join([str(h) for h in headers])
for headers in zip(
*(
msa.headers if msa.headers is not None else [""] * msa.depth
for msa in msas
)
)
]
array = np.concatenate([msa.array for msa in msas], axis=1)
return cls(array, headers)
def to_msa(self) -> MSA:
headers = (
self.headers
if self.headers is not None
else [f"seq{i}" for i in range(self.depth)]
)
entries = [
FastaEntry(header, b"".join(row).decode())
for header, row in zip(headers, self.array)
]
return MSA(entries)
@classmethod
def stack(
cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True
) -> FastMSA:
"""Stack a series of MSAs. Optionally remove the query from msas after the first."""
arrays = []
all_headers = []
for i, msa in enumerate(msas):
array = msa.array
headers = msa.headers
if i > 0 and remove_query_from_later_msas:
array = array[1:]
if headers is not None:
headers = headers[1:]
arrays.append(array)
if headers is not None:
all_headers.extend(headers)
return cls(np.concatenate(arrays, axis=0), all_headers)

83
esm/utils/parsing.py Normal file
View File

@@ -0,0 +1,83 @@
import io
from pathlib import Path
from typing import Generator, Iterable, NamedTuple
# A location for sequence data: a filesystem path (str or Path) or a text stream.
PathOrBuffer = str | Path | io.TextIOBase
# One FASTA record: the header text and its sequence string.
FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)])
def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]:
    """
    Parses a fasta file and yields FastaEntry objects

    Empty lines and lines starting with ``#`` are skipped. A header is the
    text after ``>`` with surrounding whitespace stripped; the sequence is the
    concatenation of the lines that follow it, up to the next header.
    Sequence lines that appear before the first header belong to no record
    and are ignored.

    Args:
        fasta_string: The fasta file as a string

    Returns:
        A generator of FastaEntry objects

    Raises:
        ValueError: if the input contains no header line. Because this is a
            generator, the error surfaces when it is iterated.
    """
    header = None
    seq: list[str] = []
    for line in fasta_string.splitlines():
        if not line or line[0] == "#":
            continue
        if line.startswith(">"):
            if header is not None:
                yield FastaEntry(header, "".join(seq))
            # Start every record with an empty buffer, including the first,
            # so stray lines above the first header cannot leak into it.
            seq = []
            header = line[1:].strip()
        elif header is not None:
            seq.append(line)
    if header is None:
        raise ValueError("Found no sequences in input")
    yield FastaEntry(header, "".join(seq))
def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]:
    """Read FASTA records from a path or from an already-open text stream.

    Paths whose string form ends in ``.gz`` are opened with ``gzip`` in text
    mode. Anything ``open`` rejects with ``TypeError`` is treated as a
    readable stream and used directly.

    Only handles opened here are closed; a stream supplied by the caller is
    left open, mirroring ``write_sequences``.

    Args:
        path: File path or text stream to read from.

    Returns:
        A generator of FastaEntry objects.
    """
    # Uses duck typing to try and call the right method
    # Doesn't use explicit isinstance check to support
    # inputs that are not explicitly str/Path/TextIOBase but
    # may support similar functionality
    data = None  # type: ignore
    needs_closing = False
    try:
        if str(path).endswith(".gz"):
            import gzip

            data = gzip.open(path, "rt")  # type: ignore
            needs_closing = True
        else:
            try:
                data = open(path)  # type: ignore
                needs_closing = True
            except TypeError:
                # Not something open() accepts: assume a caller-owned stream.
                data = path  # type: ignore
        yield from parse_fasta(data.read())  # type: ignore
    finally:
        if needs_closing and data is not None:
            data.close()
def read_first_sequence(path: PathOrBuffer) -> FastaEntry:
    """Return only the first FASTA record found at ``path``."""
    records = iter(read_sequences(path))
    first_record = next(records)
    return first_record
def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None:
    """Write ``(header, sequence)`` pairs in FASTA format.

    ``path`` is first tried as a filename; if ``open`` rejects it with
    ``TypeError`` it is used as a writable stream instead. Records are
    separated by a newline and no trailing newline is written. Only a handle
    opened here is closed afterwards.
    """
    opened_here = False
    handle = None
    try:
        try:
            handle = open(path, "w")  # type: ignore
        except TypeError:
            handle = path
        else:
            opened_here = True
        for index, (header, seq) in enumerate(sequences):
            if index:
                handle.write("\n")  # type: ignore
            handle.write(f">{header}\n{seq}")  # type: ignore
    finally:
        if opened_here:
            handle.close()  # type: ignore

View File

@@ -0,0 +1,157 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import TypeVar
import numpy as np
from esm.utils.misc import concat_objects, slice_any_object
T = TypeVar("T")
@dataclass(frozen=True)
class SequentialDataclass(ABC):
    """
    Base class for dataclasses whose fields share one sequence dimension.

    Multimodal records often hold several fields indexed by the same positions
    (e.g. the residues of a protein) next to fields that are not positional
    (an id, a data source). Subclasses mark the positional fields through
    dataclass field ``metadata`` and get slicing (``__getitem__``), a
    length-consistency check at construction, and ``concat`` for free; they
    only have to implement ``__len__``.

    Recognised ``metadata`` keys:
        `sequence` (bool): whether the field is positional. Default: False.
        `sequence_dim` (int): which dimension is the sequence dimension. 0
            means the value itself is indexed; 1 means the value is a
            container whose items are each indexed (e.g. a list of inverse
            folded sequences). Default: 0.
        `join_token` (Any): token inserted between pieces when concatenating.
            Default: None.

    Example:
        @dataclass(frozen=True)
        class Foo(SequentialDataclass):
            id: str
            sequence: str = field(metadata={"sequence": True, "join_token": "|"})
            tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})

            def __len__(self):
                # Must implement the __len__ method
                return len(self.sequence)

        >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
        Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143,  0.0251, -1.0717]))
        >>> foo[1:4]
        Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143,  0.0251]))
        >>> foo[np.arange(5) < 3]
        Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
        >>> Foo.concat([foo[:2], foo[3:]])
        Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335,     nan,  0.0251, -1.0717]))

        # Trying to create a type where the sequence lengths do not match raises an error
        >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
        ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
    """

    def __post_init__(self):
        self._check_sequence_lengths_match()

    @abstractmethod
    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
        """Return a copy with every non-None sequence field indexed by ``idx``."""
        if isinstance(idx, int):
            # A one-element list keeps the sliced fields sequential.
            idx = [idx]
        sliced = {}
        for fld in fields(self):
            if not fld.metadata.get("sequence", False):
                continue
            axis = fld.metadata.get("sequence_dim", 0)
            current = getattr(self, fld.name)
            if current is None:
                continue
            if axis == 0:
                # The value itself is the sequence.
                sliced[fld.name] = slice_any_object(current, idx)
            elif axis == 1:
                # Container of sequences: slice each item, rebuild same container type.
                pieces = [slice_any_object(item, idx) for item in current]
                sliced[fld.name] = current.__class__(pieces)
            else:
                raise NotImplementedError(
                    "Arbitrary slicing for different sequence length fields is not implemented"
                )
        return replace(self, **sliced)

    def _check_sequence_lengths_match(self):
        """Checks if sequence lengths of all "sequence" fields match."""
        for fld in fields(self):
            is_sequence = fld.metadata.get("sequence", False)
            # A field named "complex" is exempt from the check.
            if not is_sequence or fld.name == "complex":
                continue
            axis = fld.metadata.get("sequence_dim", 0)
            current = getattr(self, fld.name)
            if current is None:
                continue
            if axis == 0:
                if len(current) != len(self):
                    raise ValueError(
                        f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(current)}"
                    )
            elif axis == 1:
                for item in current:
                    if len(item) != len(self):
                        raise ValueError(
                            f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
                        )
            else:
                raise NotImplementedError(
                    "Arbitrary matching for different sequence length fields is not implemented"
                )

    @classmethod
    def concat(cls, items: list[T], **kwargs) -> T:
        """Concatenate ``items`` along the sequence dimension.

        Sequence fields are joined using each field's ``join_token``; a field
        that is None on the first item is skipped. All other fields are taken
        from the first item, and ``kwargs`` override any field at the end.
        """
        merged = {}
        for fld in fields(cls):
            if not fld.metadata.get("sequence", False):
                continue
            axis = fld.metadata.get("sequence_dim", 0)
            joiner = fld.metadata.get("join_token", None)
            if getattr(items[0], fld.name) is None:
                continue
            per_item = [getattr(item, fld.name) for item in items]
            if axis == 0:
                merged[fld.name] = concat_objects(per_item, joiner)
            elif axis == 1:
                # Join the k-th entry of every item's container together.
                joined = [concat_objects(group, joiner) for group in zip(*per_item)]
                container_type = getattr(items[0], fld.name).__class__
                merged[fld.name] = container_type(joined)
            else:
                raise NotImplementedError(
                    "Arbitrary joining for different sequence length fields is not implemented"
                )
        merged.update(kwargs)
        return replace(
            items[0],  # type: ignore
            **merged,
        )

View File

@@ -0,0 +1,95 @@
from dataclasses import dataclass
from typing import Any, Sequence
import numpy as np
@dataclass
class Modification:
    """A modification at one position of a sequence input.

    Serialized as ``{"position": ..., "ccd": ...}`` by
    ``serialize_structure_prediction_input``.
    """

    position: int  # zero-indexed
    ccd: str  # presumably a Chemical Component Dictionary code — TODO confirm
@dataclass
class ProteinInput:
    """A protein chain entry; serialized with ``"type": "protein"``."""

    id: str | list[str]  # a single id or a list of ids
    sequence: str
    modifications: list[Modification] | None = None  # optional; omitted from output when empty
@dataclass
class RNAInput:
    """An RNA chain entry; serialized with ``"type": "rna"``."""

    id: str | list[str]  # a single id or a list of ids
    sequence: str
    modifications: list[Modification] | None = None  # optional; omitted from output when empty
@dataclass
class DNAInput:
    """A DNA chain entry; serialized with ``"type": "dna"``."""

    id: str | list[str]  # a single id or a list of ids
    sequence: str
    modifications: list[Modification] | None = None  # optional; omitted from output when empty
@dataclass
class LigandInput:
    """A ligand entry; serialized with ``"type": "ligand"``."""

    id: str | list[str]  # a single id or a list of ids
    smiles: str  # presumably a SMILES string for the ligand — TODO confirm
    ccd: list[str] | None = None  # optional; serialized as-is, including None
@dataclass
class DistogramConditioning:
    """A distogram array associated with one chain id."""

    chain_id: str
    distogram: np.ndarray  # shape is not fixed here — TODO document expected shape
@dataclass
class PocketConditioning:
    """Pocket conditioning: a binder chain id plus a list of contacts."""

    binder_chain_id: str
    # Each contact is a (str, int) pair — presumably (chain id, residue position); TODO confirm
    contacts: list[tuple[str, int]]
@dataclass
class StructurePredictionInput:
    """Full structure-prediction input: sequence entries plus optional conditioning."""

    sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
    pocket: PocketConditioning | None = None  # optional pocket conditioning
    distogram_conditioning: list[DistogramConditioning] | None = None  # optional, one item per chain id
def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
chain_data: dict[str, Any] = {
"sequence": seq_input.sequence,
"id": seq_input.id,
"type": chain_type,
}
if hasattr(seq_input, "modifications") and seq_input.modifications:
mods = [
{"position": mod.position, "ccd": mod.ccd}
for mod in seq_input.modifications
]
chain_data["modifications"] = mods
return chain_data
sequences = []
for seq_input in all_atom_input.sequences:
if isinstance(seq_input, ProteinInput):
sequences.append(create_chain_data(seq_input, "protein"))
elif isinstance(seq_input, RNAInput):
sequences.append(create_chain_data(seq_input, "rna"))
elif isinstance(seq_input, DNAInput):
sequences.append(create_chain_data(seq_input, "dna"))
elif isinstance(seq_input, LigandInput):
sequences.append(
{
"smiles": seq_input.smiles,
"id": seq_input.id,
"ccd": seq_input.ccd,
"type": "ligand",
}
)
else:
raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")
return {"sequences": sequences}

View File

@@ -264,7 +264,7 @@ def compute_lddt_ca(
if all_atom_pred_pos.dim() != 3:
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
all_atom_mask = all_atom_mask[..., ca_pos]
return compute_lddt(
all_atom_pred_pos,

View File

@@ -0,0 +1,938 @@
from __future__ import annotations
import io
import os
import re
from dataclasses import asdict, dataclass
from pathlib import Path
from subprocess import check_output
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, List
import biotite.structure.io.pdbx as pdbx
import brotli
import msgpack
import numpy as np
import torch
from esm.utils import residue_constants
from esm.utils.structure.metrics import compute_lddt, compute_rmsd
from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata
@dataclass
class MolecularComplexResult:
    """Result of molecular complex folding

    Holds the predicted ``complex`` plus optional model outputs; every field
    other than ``complex`` defaults to None.
    """

    complex: MolecularComplex
    # Optional confidence outputs (tensor shapes are not fixed here).
    plddt: torch.Tensor | None = None
    ptm: float | None = None
    iptm: float | None = None
    pae: torch.Tensor | None = None
    distogram: torch.Tensor | None = None
    pair_chains_iptm: torch.Tensor | None = None
    # Optional embedding outputs.
    output_embedding_sequence: torch.Tensor | None = None
    output_embedding_pair_pooled: torch.Tensor | None = None
@dataclass
class MolecularComplexMetadata:
    """Metadata for MolecularComplex objects."""

    entity_lookup: dict[int, str]  # integer key -> entity identifier string
    chain_lookup: dict[int, str]  # integer key -> chain identifier string
    assembly_composition: dict[str, list[str]] | None = None  # optional
@dataclass
class Molecule:
    """Represents a single molecule/token within a MolecularComplex.

    Instances are produced by ``MolecularComplex.__getitem__``.
    """

    token: str  # token string taken from the complex's sequence
    token_idx: int  # index of the token within the complex
    atom_positions: np.ndarray  # [N_atoms, 3]
    atom_elements: np.ndarray  # [N_atoms] element strings
    residue_type: int
    molecule_type: int  # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
    confidence: float  # the complex's plddt value for this token
@dataclass(frozen=True)
class MolecularComplex:
"""
Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands.
Uses a flat atom representation with token-based sequence indexing, supporting all atom types
beyond the traditional atom37 protein representation.
"""
id: str
sequence: List[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP']
# Flat atom arrays - simplified representation
atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates
atom_elements: np.ndarray # [N_atoms] element strings
# Token-to-atom mapping for efficient access
token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
# Confidence data
plddt: np.ndarray # Per-token confidence scores [N_tokens]
# Metadata
metadata: MolecularComplexMetadata
def __post_init__(self):
"""Validate array dimensions."""
n_tokens = len(self.sequence)
assert (
self.token_to_atoms.shape[0] == n_tokens
), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
assert (
self.plddt.shape[0] == n_tokens
), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
def __len__(self) -> int:
"""Return number of tokens."""
return len(self.sequence)
def __getitem__(self, idx: int) -> Molecule:
"""Access individual molecules/tokens by index."""
if idx >= len(self.sequence) or idx < 0:
raise IndexError(
f"Token index {idx} out of range for {len(self.sequence)} tokens"
)
token = self.sequence[idx]
start_atom, end_atom = self.token_to_atoms[idx]
# Extract atom data for this token
token_atom_positions = self.atom_positions[start_atom:end_atom]
token_atom_elements = self.atom_elements[start_atom:end_atom]
# Default values for residue/molecule type (would be extended based on actual implementation)
residue_type = 0 # Default to standard residue
molecule_type = 0 # Default to protein
return Molecule(
token=token,
token_idx=idx,
atom_positions=token_atom_positions,
atom_elements=token_atom_elements,
residue_type=residue_type,
molecule_type=molecule_type,
confidence=self.plddt[idx],
)
@property
def atom_coordinates(self) -> np.ndarray:
"""Get flat array of all atom coordinates [N_atoms, 3]."""
return self.atom_positions
# Conversion methods
@classmethod
def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex":
    """Build a MolecularComplex from a ProteinComplex.

    Chain-break characters (``"|"``) in ``pc.sequence`` produce no token.
    Every other residue becomes one three-letter token (``"UNK"`` when the
    one-letter code is unknown) whose atoms are the atom37 slots flagged in
    ``pc.atom37_mask``, emitted in ``residue_constants.atom_types`` order.

    Args:
        pc: ProteinComplex object with atom37 representation

    Returns:
        MolecularComplex with flat atom arrays and token-based indexing
    """
    from esm.utils import residue_constants

    residue_letters = [aa for aa in pc.sequence if aa != "|"]
    sequence_tokens = [
        residue_constants.restype_1to3.get(aa, "UNK") for aa in residue_letters
    ]

    flat_positions = []
    flat_elements = []
    token_to_atoms = []

    # NOTE(review): the per-residue arrays are indexed by a counter that skips
    # "|" characters. That is only aligned if pc's arrays carry no rows for
    # chain breaks — confirm against ProteinComplex for multi-chain inputs.
    for residue_idx in range(len(residue_letters)):
        res_positions = pc.atom37_positions[residue_idx]  # [37, 3]
        res_mask = pc.atom37_mask[residue_idx]  # [37]

        token_start = len(flat_positions)
        for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
            if not res_mask[atom_type_idx]:
                continue
            flat_positions.append(res_positions[atom_type_idx])
            # The element is taken to be the first character of the atom name.
            flat_elements.append(atom_name[0] if atom_name else "C")

        # Half-open [start, end) range of this token's atoms.
        token_to_atoms.append([token_start, len(flat_positions)])

    atom_positions = np.array(flat_positions, dtype=np.float32)
    atom_elements = np.array(flat_elements, dtype=object)
    token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)

    # One confidence value per token, read with the same skipping counter.
    confidence_array = np.array(
        [pc.confidence[k] for k in range(len(residue_letters))], dtype=np.float32
    )

    # MolecularComplexMetadata stores entity identifiers as strings.
    metadata = MolecularComplexMetadata(
        entity_lookup={k: str(v) for k, v in pc.metadata.entity_lookup.items()},
        chain_lookup=pc.metadata.chain_lookup,
        assembly_composition=pc.metadata.assembly_composition,
    )

    return cls(
        id=pc.id,
        sequence=sequence_tokens,
        atom_positions=atom_positions,
        atom_elements=atom_elements,
        token_to_atoms=token_to_atoms_array,
        plddt=confidence_array,
        metadata=metadata,
    )
def to_protein_complex(self) -> ProteinComplex:
    """Convert the protein part of this complex to a ProteinComplex.

    Only tokens that are keys of ``residue_constants.restype_3to1`` are kept;
    everything else (ligands, nucleic acids) is dropped. Each kept token's
    flat atoms are written back into atom37 slots by walking
    ``residue_constants.atom_types`` and filling, in order, the slots that
    ``residue_constants.residue_atoms`` lists for that residue.

    All residues are placed in a single entity and chain, and residue
    indices are renumbered from 1.

    Returns:
        ProteinComplex with protein residues only, excluding ligands/nucleic acids

    Raises:
        ValueError: if the complex contains no protein tokens.
    """
    from esm.utils import residue_constants

    protein_indices = [
        i
        for i, token in enumerate(self.sequence)
        if token in residue_constants.restype_3to1
    ]
    if not protein_indices:
        raise ValueError("No protein tokens found in MolecularComplex")
    protein_tokens = [self.sequence[i] for i in protein_indices]
    n_residues = len(protein_indices)

    atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
    atom37_mask = np.zeros((n_residues, 37), dtype=bool)

    single_letter_sequence = "".join(
        residue_constants.restype_3to1[token] for token in protein_tokens
    )
    protein_confidence = self.plddt[protein_indices]

    for res_idx, (token_idx, token) in enumerate(
        zip(protein_indices, protein_tokens)
    ):
        start_atom, end_atom = self.token_to_atoms[token_idx]
        res_atom_positions = self.atom_positions[start_atom:end_atom]
        # Atom names this residue type is expected to have (same for every slot).
        expected_atoms = residue_constants.residue_atoms.get(token, [])

        # Mirror from_protein_complex: atoms were emitted in atom_types order,
        # so consume the flat atoms in that order. This assumes no expected
        # atom was missing in the source mask.
        consumed = 0
        for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
            if atom_name not in expected_atoms:
                continue
            if consumed >= len(res_atom_positions):
                # No flat atoms left for this residue; remaining slots stay empty.
                break
            atom37_positions[res_idx, atom_type_idx] = res_atom_positions[consumed]
            atom37_mask[res_idx, atom_type_idx] = True
            consumed += 1

    # Simplification: every protein residue goes to one entity and one chain.
    entity_id = np.zeros(n_residues, dtype=np.int64)
    chain_id = np.zeros(n_residues, dtype=np.int64)
    sym_id = np.zeros(n_residues, dtype=np.int64)
    residue_index = np.arange(1, n_residues + 1, dtype=np.int64)
    insertion_code = np.array([""] * n_residues, dtype=object)

    protein_metadata = ProteinComplexMetadata(
        entity_lookup={0: 1},  # Single entity (int for ProteinComplexMetadata)
        chain_lookup={0: "A"},  # Single chain
        assembly_composition=self.metadata.assembly_composition,
    )

    return ProteinComplex(
        id=self.id,
        sequence=single_letter_sequence,
        entity_id=entity_id,
        chain_id=chain_id,
        sym_id=sym_id,
        residue_index=residue_index,
        insertion_code=insertion_code,
        atom37_positions=atom37_positions,
        atom37_mask=atom37_mask,
        confidence=protein_confidence,
        metadata=protein_metadata,
    )
@classmethod
def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex":
    """Read MolecularComplex from mmcif file or string.

    ``inp`` is treated as a file path when ``os.path.exists(inp)`` is true and
    as mmCIF text otherwise. Atoms are grouped by chain id and residue id;
    each group except water (``HOH``) becomes one token named after its
    residue name, with all of its atoms appended to the flat arrays. Chains
    and residues are emitted in sorted order.

    Args:
        inp: Path to mmCIF file or mmCIF content as string
        id: Optional identifier to assign to the complex. Defaults to the
            file stem for a path, or ``"complex_from_string"`` for text.

    Returns:
        MolecularComplex with all molecules (proteins, ligands, nucleic acids)
    """
    from io import StringIO

    # Check if input is a file path or mmCIF string content
    if os.path.exists(inp):
        # Input is a file path
        mmcif_file = pdbx.CIFFile.read(inp)
    else:
        # Input is mmCIF string content
        mmcif_file = pdbx.CIFFile.read(StringIO(inp))

    # Get structure - handle missing model information gracefully
    try:
        structure = pdbx.get_structure(mmcif_file, model=1)
    except (KeyError, ValueError):
        # Fallback for mmCIF files without model information
        # NOTE(review): both fallbacks omit the model / pass model=None, which
        # look equivalent; confirm what biotite returns here and that the
        # per-atom iteration below still applies to it.
        try:
            structure = pdbx.get_structure(mmcif_file)
        except Exception:
            # Last resort: use the first available model or all atoms
            structure = pdbx.get_structure(mmcif_file, model=None)

    # Type hint for pyright - structure is an AtomArray which is iterable
    if TYPE_CHECKING:
        structure: Any = structure

    # Get entity information from mmCIF (best effort: any failure leaves
    # entity_info empty or partially filled)
    entity_info = {}
    try:
        # Access the first block in CIFFile
        block = mmcif_file[0]
        if "entity" in block:
            entity_category = block["entity"]
            if "id" in entity_category and "type" in entity_category:
                entity_ids = entity_category["id"]
                entity_types = entity_category["type"]
                # Convert CIFColumn to list for iteration
                if hasattr(entity_ids, "__iter__") and hasattr(
                    entity_types, "__iter__"
                ):
                    # Type annotation to help pyright understand these are iterable
                    entity_ids_list = list(entity_ids)  # type: ignore
                    entity_types_list = list(entity_types)  # type: ignore
                    for eid, etype in zip(entity_ids_list, entity_types_list):
                        entity_info[eid] = etype
    except Exception:
        pass

    # Initialize arrays for flat atom representation
    sequence_tokens = []
    flat_positions = []
    flat_elements = []
    token_to_atoms = []
    confidence_scores = []
    atom_idx = 0

    # Group atoms by chain and residue
    # (keyed by res_id only; the first atom seen fixes res_name / is_hetero)
    chain_residue_groups = {}
    for atom in structure:
        chain_id = atom.chain_id
        res_id = atom.res_id
        res_name = atom.res_name
        if chain_id not in chain_residue_groups:
            chain_residue_groups[chain_id] = {}
        if res_id not in chain_residue_groups[chain_id]:
            chain_residue_groups[chain_id][res_id] = {
                "atoms": [],
                "res_name": res_name,
                "is_hetero": atom.hetero,
            }
        chain_residue_groups[chain_id][res_id]["atoms"].append(atom)

    # Process each chain and residue
    for chain_id in sorted(chain_residue_groups.keys()):
        residues = chain_residue_groups[chain_id]
        for res_id in sorted(residues.keys()):
            residue_data = residues[res_id]
            res_name = residue_data["res_name"]
            atoms = residue_data["atoms"]
            is_hetero = residue_data["is_hetero"]

            # Skip water molecules
            if res_name == "HOH":
                continue

            # Determine token name
            # (every branch currently yields the residue name unchanged; the
            # split only labels the three molecule categories)
            if not is_hetero and res_name in residue_constants.restype_3to1:
                # Standard amino acid
                token_name = res_name
            elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]:
                # Nucleotide
                token_name = res_name
            else:
                # Ligand or other molecule
                token_name = res_name

            sequence_tokens.append(token_name)
            token_start = atom_idx

            # Add all atoms from this residue
            for atom in atoms:
                flat_positions.append(atom.coord)
                # Get element character
                element = atom.element
                flat_elements.append(element)
                atom_idx += 1

            # Record token-to-atom mapping as a half-open [start, end) range
            token_to_atoms.append([token_start, atom_idx])

            # Add confidence score: first atom's b_factor / 100, capped at 1.0.
            # When the atom has no b_factor attribute the default of 50.0
            # gives a confidence of 0.5.
            # NOTE(review): get_structure is called above without requesting
            # extra fields; confirm b_factor is actually present on the atoms,
            # otherwise every token gets 0.5.
            bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0
            confidence_scores.append(min(bfactor / 100.0, 1.0))

    # Convert to numpy arrays
    if not flat_positions:
        # Create minimal arrays if no atoms found
        atom_positions = np.zeros((0, 3), dtype=np.float32)
        atom_elements = np.zeros(0, dtype=object)
        token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
    else:
        atom_positions = np.array(flat_positions, dtype=np.float32)
        atom_elements = np.array(flat_elements, dtype=object)
        token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)

    confidence_array = np.array(confidence_scores, dtype=np.float32)

    # Create metadata
    # (chain_lookup follows first-appearance order of chains, whereas the
    # tokens above were emitted in sorted chain order)
    metadata = MolecularComplexMetadata(
        entity_lookup=entity_info,
        chain_lookup={
            i: chain_id for i, chain_id in enumerate(chain_residue_groups.keys())
        },
        assembly_composition=None,
    )

    # Set complex ID - if input was a path, use the stem; otherwise use default
    if os.path.exists(inp):
        complex_id = id or Path(inp).stem
    else:
        complex_id = id or "complex_from_string"

    return cls(
        id=complex_id,
        sequence=sequence_tokens,
        atom_positions=atom_positions,
        atom_elements=atom_elements,
        token_to_atoms=token_to_atoms_array,
        plddt=confidence_array,
        metadata=metadata,
    )
def to_mmcif(self) -> str:
    """Write MolecularComplex to mmcif string.

    The output is a simplified mmCIF: one ``polymer`` entity for all protein
    tokens (if any), one ``non-polymer`` entity per distinct other token, and
    an ``_atom_site`` loop in which every protein atom is assigned to chain
    ``A`` / entity 1 and every other atom to chain ``B`` / entity 2.
    B-factors are ``plddt * 100``.

    Non-polymer entities are listed in order of first appearance in the
    sequence, so the output is identical from run to run.

    Atom names are not stored on the complex, so they are reconstructed.
    Protein atoms are named by walking ``residue_constants.atom_types`` and
    keeping the names listed for the residue. This is the same ordering
    ``from_protein_complex`` uses to emit atoms and ``to_protein_complex``
    uses to read them back. Other tokens get generic ``C1``, ``C2``, ... names.

    Returns:
        String representation of the complex in mmCIF format
    """
    protein_present = any(
        token in residue_constants.restype_3to1 for token in self.sequence
    )
    # dict.fromkeys de-duplicates while keeping first-appearance order; a
    # plain set would order the entities differently between interpreter runs.
    other_kinds = list(
        dict.fromkeys(
            token
            for token in self.sequence
            if token not in residue_constants.restype_3to1
        )
    )

    lines = [
        # Header
        f"data_{self.id}",
        "#",
        f"_entry.id {self.id}",
        "#",
        # Structure metadata
        f"_struct.entry_id {self.id}",
        "_struct.title 'Protein Structure'",
        "#",
        # Entity information
        "loop_",
        "_entity.id",
        "_entity.type",
        "_entity.pdbx_description",
    ]
    entity_id = 1
    if protein_present:
        lines.append(f"{entity_id} polymer 'Protein chain'")
        entity_id += 1
    for token in other_kinds:
        lines.append(f"{entity_id} non-polymer 'Ligand {token}'")
        entity_id += 1
    lines.append("#")

    # Chain assignments
    lines.extend(["loop_", "_struct_asym.id", "_struct_asym.entity_id"])
    chain_counter = 0
    chain_id = "A"
    if protein_present:
        lines.append(f"{chain_id} 1")
        chain_counter += 1
        chain_id = chr(ord(chain_id) + 1)
    entity_id = 2
    for _ in other_kinds:
        lines.append(f"{chain_id} {entity_id}")
        entity_id += 1
        chain_counter += 1
        # Stop advancing the chain letter once the counter reaches 26.
        if chain_counter < 26:
            chain_id = chr(ord(chain_id) + 1)
    lines.append("#")

    # Atom site information
    lines.extend(
        [
            "loop_",
            "_atom_site.group_PDB",
            "_atom_site.id",
            "_atom_site.type_symbol",
            "_atom_site.label_atom_id",
            "_atom_site.label_alt_id",
            "_atom_site.label_comp_id",
            "_atom_site.label_asym_id",
            "_atom_site.label_entity_id",
            "_atom_site.label_seq_id",
            "_atom_site.pdbx_PDB_ins_code",
            "_atom_site.Cartn_x",
            "_atom_site.Cartn_y",
            "_atom_site.Cartn_z",
            "_atom_site.occupancy",
            "_atom_site.B_iso_or_equiv",
            "_atom_site.pdbx_PDB_model_num",
            "_atom_site.auth_seq_id",
            "_atom_site.auth_comp_id",
            "_atom_site.auth_asym_id",
            "_atom_site.auth_atom_id",
        ]
    )

    atom_id = 1
    seq_id = 1
    for token_idx, token in enumerate(self.sequence):
        start_atom, end_atom = self.token_to_atoms[token_idx]
        n_atoms = end_atom - start_atom

        # Determine if this is a protein residue or ligand
        is_protein = token in residue_constants.restype_3to1
        group_pdb = "ATOM" if is_protein else "HETATM"
        current_entity_id = 1 if is_protein else 2  # Simplified entity assignment
        current_chain_id = "A" if is_protein else "B"  # Simplified chain assignment

        if is_protein:
            expected_atoms = residue_constants.residue_atoms.get(
                token, ["N", "CA", "C", "O"]
            )
            # Name atoms in atom_types order, matching how the flat atoms
            # were laid out for this residue.
            ordered_names = [
                name
                for name in residue_constants.atom_types
                if name in expected_atoms
            ]
            atom_names = ordered_names[:n_atoms]
        else:
            # Generate generic atom names for ligands
            atom_names = [f"C{i+1}" for i in range(n_atoms)]

        # Pad atom names if the token has more atoms than known names
        while len(atom_names) < n_atoms:
            atom_names.append(f"X{len(atom_names)+1}")

        # One B-factor per token, shared by all of its atoms.
        bfactor = (
            self.plddt[token_idx] * 100.0 if len(self.plddt) > token_idx else 50.0
        )

        for atom_name, global_atom_idx in zip(
            atom_names, range(start_atom, end_atom)
        ):
            pos = self.atom_positions[global_atom_idx]
            element_char = self.atom_elements[global_atom_idx]
            element_symbol = element_char if isinstance(element_char, str) else "C"

            line = (
                f"{group_pdb:<6} {atom_id:>5} {element_symbol:<2} {atom_name:<4} . "
                f"{token:<3} {current_chain_id} {current_entity_id} {seq_id:>3} ? "
                f"{pos[0]:>8.3f} {pos[1]:>8.3f} {pos[2]:>8.3f} 1.00 {bfactor:>6.2f} 1 "
                f"{seq_id:>3} {token:<3} {current_chain_id} {atom_name:<4}"
            )
            lines.append(line)
            atom_id += 1

        seq_id += 1

    lines.append("#")
    return "\n".join(lines)
def dockq(self, native: "MolecularComplex") -> Any:
    """Compute DockQ score against native structure.

    Both complexes are converted with ``to_protein_complex()`` and their
    chain IDs are normalized with ``normalize_chain_ids_for_pdb()`` before
    scoring. Scoring is delegated to ``ProteinComplex.dockq()``; if that
    call raises, ``_compute_dockq_manual()`` is used instead.

    Args:
        native: Native MolecularComplex to compute DockQ against

    Returns:
        DockQ result containing score and alignment information. The exact
        type depends on which of the two code paths produced it.

    Raises:
        ValueError: If either complex cannot be converted to a
            ProteinComplex. The original error is attached as ``__cause__``.
    """
    try:
        self_pc = self.to_protein_complex()
        native_pc = native.to_protein_complex()
    except ValueError as e:
        # Chain explicitly so the traceback shows the conversion error as
        # the direct cause rather than as an incidental context.
        raise ValueError(
            f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
        ) from e

    # Normalize chain IDs for PDB compatibility
    self_pc = self_pc.normalize_chain_ids_for_pdb()
    native_pc = native_pc.normalize_chain_ids_for_pdb()

    try:
        return self_pc.dockq(native_pc)
    except Exception:
        # Deliberate best-effort: any failure of ProteinComplex.dockq()
        # falls back to the manual DockQ computation.
        return self._compute_dockq_manual(native)
def _compute_dockq_manual(self, native: "MolecularComplex") -> Any:
"""Manual DockQ computation fallback."""
# Imports moved to top of file
# Convert both complexes to ProteinComplex format
try:
self_pc = self.to_protein_complex()
native_pc = native.to_protein_complex()
except ValueError as e:
raise ValueError(
f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
)
# Normalize chain IDs for PDB compatibility
self_pc = self_pc.normalize_chain_ids_for_pdb()
native_pc = native_pc.normalize_chain_ids_for_pdb()
# Write temporary PDB files and run DockQ
with TemporaryDirectory() as tdir:
dir_path = Path(tdir)
self_pdb = dir_path / "self.pdb"
native_pdb = dir_path / "native.pdb"
# Write PDB files
self_pc.to_pdb(self_pdb)
native_pc.to_pdb(native_pdb)
# Run DockQ
try:
output = check_output(["DockQ", str(self_pdb), str(native_pdb)])
output_text = output.decode()
# Parse DockQ output
lines = output_text.split("\n")
# Find the total DockQ score
dockq_score = None
for line in lines:
if "Total DockQ" in line:
match = re.search(r"Total DockQ.*: ([\d.]+)", line)
if match:
dockq_score = float(match.group(1))
break
if dockq_score is None:
# Try to find individual DockQ scores
for line in lines:
if line.startswith("DockQ") and ":" in line:
try:
dockq_score = float(line.split(":")[1].strip())
break
except (ValueError, IndexError):
continue
if dockq_score is None:
raise ValueError("Could not parse DockQ score from output")
# Return a simple result structure
return {
"total_dockq": dockq_score,
"raw_output": output_text,
"aligned": self, # Return self as aligned structure
}
except FileNotFoundError:
raise RuntimeError(
"DockQ is not installed. Please install DockQ to use this method."
)
except Exception as e:
raise RuntimeError(f"DockQ computation failed: {e}")
def rmsd(self, target: "MolecularComplex", **kwargs) -> float:
    """Compute the RMSD between per-token centroids of two complexes.

    Every token is represented by the mean position of its atoms. Tokens
    for which either complex has no atoms are left out of the comparison.

    Args:
        target: Target MolecularComplex to compute RMSD against
        **kwargs: Additional arguments passed to compute_rmsd

    Returns:
        float: RMSD value between the two structures

    Raises:
        ValueError: If the complexes differ in token count, or if no token
            has atoms in both complexes.
    """
    if len(self) != len(target):
        raise ValueError(
            f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
        )

    # Collect (own centroid, target centroid) for every token that has
    # atoms on both sides.
    centroid_pairs = []
    for idx in range(len(self)):
        own_lo, own_hi = self.token_to_atoms[idx]
        ref_lo, ref_hi = target.token_to_atoms[idx]
        own_atoms = self.atom_positions[own_lo:own_hi]
        ref_atoms = target.atom_positions[ref_lo:ref_hi]
        if len(own_atoms) != 0 and len(ref_atoms) != 0:
            centroid_pairs.append((own_atoms.mean(axis=0), ref_atoms.mean(axis=0)))

    if not centroid_pairs:
        raise ValueError("No valid atoms found for RMSD computation")

    own_centroids, ref_centroids = zip(*centroid_pairs)
    # Add a leading batch dimension: [1, N, 3] for coordinates, [1, N] for the mask.
    mobile_tensor = torch.from_numpy(np.stack(own_centroids, axis=0)).unsqueeze(0)
    target_tensor = torch.from_numpy(np.stack(ref_centroids, axis=0)).unsqueeze(0)
    # Only tokens with atoms were kept, so every mask entry is True.
    mask_tensor = torch.ones((1, len(centroid_pairs)), dtype=torch.bool)

    rmsd_value = compute_rmsd(
        mobile=mobile_tensor,
        target=target_tensor,
        atom_exists_mask=mask_tensor,
        reduction="batch",
        **kwargs,
    )
    return float(rmsd_value)
def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float:
    """Compute an LDDT score between per-token centroids of two complexes.

    Despite the name, each token is represented by the mean position of
    all of its atoms. Tokens for which either complex has no atoms keep
    their slot: they get NaN coordinates and a False mask entry.

    Args:
        target: Target MolecularComplex to compute LDDT against
        **kwargs: Additional arguments passed to compute_lddt

    Returns:
        float: LDDT value between the two structures

    Raises:
        ValueError: If the complexes differ in token count, or if no token
            has atoms in both complexes.
    """
    if len(self) != len(target):
        raise ValueError(
            f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
        )

    pred_points = []
    ref_points = []
    valid = []
    for idx in range(len(self)):
        pred_lo, pred_hi = self.token_to_atoms[idx]
        ref_lo, ref_hi = target.token_to_atoms[idx]
        pred_atoms = self.atom_positions[pred_lo:pred_hi]
        ref_atoms = target.atom_positions[ref_lo:ref_hi]

        has_atoms = len(pred_atoms) != 0 and len(ref_atoms) != 0
        if has_atoms:
            pred_center = pred_atoms.mean(axis=0)
            ref_center = ref_atoms.mean(axis=0)
        else:
            # Placeholder coordinates; the mask entry below marks them invalid.
            pred_center = np.full(3, np.nan)
            ref_center = np.full(3, np.nan)
        pred_points.append(pred_center)
        ref_points.append(ref_center)
        valid.append(has_atoms)

    if not any(valid):
        raise ValueError("No valid atoms found for LDDT computation")

    # Add a leading batch dimension: [1, N, 3] for coordinates, [1, N] for the mask.
    mobile_tensor = torch.from_numpy(np.stack(pred_points, axis=0)).unsqueeze(0)
    target_tensor = torch.from_numpy(np.stack(ref_points, axis=0)).unsqueeze(0)
    mask_tensor = torch.tensor(valid, dtype=torch.bool).unsqueeze(0)

    lddt_value = compute_lddt(
        all_atom_pred_pos=mobile_tensor,
        all_atom_positions=target_tensor,
        all_atom_mask=mask_tensor,
        per_residue=False,  # Return overall LDDT score
        **kwargs,
    )
    return float(lddt_value)
def state_dict(self):
    """Return the instance attributes as a compact, JSON-friendly dict.

    Every numpy array attribute is narrowed where possible (int64 to int32,
    float64/float32 to float16) and then converted to a nested list. A
    MolecularComplexMetadata attribute is converted with ``asdict``. All
    other attributes are passed through unchanged.
    """
    packed = {}
    for name, value in vars(self).items():
        if isinstance(value, np.ndarray):
            if value.dtype == np.int64:
                value = value.astype(np.int32)
            elif value.dtype == np.float64 or value.dtype == np.float32:
                value = value.astype(np.float16)
            packed[name] = value.tolist()
        elif isinstance(value, MolecularComplexMetadata):
            packed[name] = asdict(value)
        else:
            packed[name] = value
    return packed
def to_blob(self) -> bytes:
    """Serialize this complex: msgpack-encode ``state_dict()`` and brotli-compress the result."""
    packed = msgpack.dumps(self.state_dict())
    return brotli.compress(packed, quality=5)
@classmethod
def from_state_dict(cls, dct):
    """Rebuild an instance from a dict produced by ``state_dict()``.

    List values stored under ``atom_positions``, ``atom_elements``,
    ``token_to_atoms`` and ``plddt`` are turned back into numpy arrays.
    ``atom_positions`` and ``plddt`` arrays are cast to float32 and
    ``token_to_atoms`` to int32. The ``metadata`` entry is expanded into a
    MolecularComplexMetadata.

    The input mapping is not modified: a new dict is built, so the same
    state dict can safely be reused or passed in again.

    Args:
        dct: Mapping of constructor field names to stored values. Must
            contain a ``metadata`` entry usable as keyword arguments for
            MolecularComplexMetadata.

    Returns:
        A new instance of ``cls`` built from the restored fields.
    """
    array_keys = ("atom_positions", "atom_elements", "token_to_atoms", "plddt")
    restored = {}
    for k, v in dct.items():
        if isinstance(v, list) and k in array_keys:
            v = np.array(v)
        # The dtype casts also apply to arrays that were passed in as
        # ndarrays rather than lists.
        if isinstance(v, np.ndarray):
            if k in ("atom_positions", "plddt"):
                v = v.astype(np.float32)
            elif k == "token_to_atoms":
                v = v.astype(np.int32)
        restored[k] = v
    restored["metadata"] = MolecularComplexMetadata(**restored["metadata"])
    return cls(**restored)
@classmethod
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
    """Load an instance from a brotli-compressed msgpack blob.

    Args:
        input: A filesystem path (``Path`` or ``str``) whose bytes are read,
            an ``io.BytesIO`` whose full contents are used, or the raw
            compressed payload itself.

    Returns:
        The instance built by ``from_state_dict`` from the decoded state.
    """
    if isinstance(input, (Path, str)):
        payload = Path(input).read_bytes()
    elif isinstance(input, io.BytesIO):
        payload = input.getvalue()
    else:
        payload = input
    state = msgpack.loads(brotli.decompress(payload), strict_map_key=False)
    return cls.from_state_dict(state)

View File

@@ -377,11 +377,14 @@ class ProteinComplex:
assert self.metadata.mmcif is not None
return get_assembly_fast(self.metadata.mmcif, assembly_id=id)
def state_dict(self, backbone_only=False):
def state_dict(self, backbone_only=False, json_serializable=False):
"""This state dict is optimized for storage, so it turns things to fp16 whenever
possible. Note that we also only support int32 residue indices, I'm hoping we don't
need more than 2**32 residues..."""
dct = {k: v for k, v in vars(self).items()}
if backbone_only:
dct["atom37_mask"][:, 3:] = False
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
for k, v in dct.items():
if isinstance(v, np.ndarray):
match v.dtype:
@@ -391,9 +394,10 @@ class ProteinComplex:
dct[k] = v.astype(np.float16)
case _:
pass
if json_serializable:
dct[k] = v.tolist()
elif isinstance(v, ProteinComplexMetadata):
dct[k] = asdict(v)
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
dct["metadata"]["mmcif"] = None
# These can be populated with non-serializable objects and are not needed for reconstruction
dct.pop("atoms", None)
@@ -406,6 +410,10 @@ class ProteinComplex:
@classmethod
def from_state_dict(cls, dct):
for k, v in dct.items():
if isinstance(v, list):
dct[k] = np.array(v)
atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan)
atom37[dct["atom37_mask"]] = dct["atom37_positions"]
dct["atom37_positions"] = atom37

45
esm/utils/system.py Normal file
View File

@@ -0,0 +1,45 @@
import io
import subprocess
import typing as T
from pathlib import Path
# A filesystem location given either as a plain string or as a pathlib.Path.
PathLike = T.Union[str, Path]
# Either a PathLike filesystem location or an in-memory text buffer (io.StringIO).
PathOrBuffer = T.Union[PathLike, io.StringIO]
def run_subprocess_with_errorcheck(
*popenargs,
capture_output: bool = False,
quiet: bool = False,
env: dict[str, str] | None = None,
shell: bool = False,
executable: str | None = None,
**kws,
) -> subprocess.CompletedProcess:
"""A command similar to subprocess.run, however the errormessage will
contain the stderr when using this function. This makes it significantly
easier to diagnose issues.
"""
try:
if capture_output:
stdout = subprocess.PIPE
elif quiet:
stdout = subprocess.DEVNULL
else:
stdout = None
p = subprocess.run(
*popenargs,
stderr=subprocess.PIPE,
stdout=stdout,
check=True,
env=env,
shell=shell,
executable=executable,
**kws,
)
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}"
)
return p

View File

@@ -1726,8 +1726,8 @@ packages:
requires_python: '>=3.8'
- pypi: ./
name: esm
version: 3.2.2
sha256: c14e2546bda5f0910c14acfabb7ea334e7171905c6799b43178f0420a92d6f3e
version: 3.2.2.post2
sha256: 3f59a2977c85d35b4b1353902fa90e35d02acbabe6ffb506727bd406ec987ad1
requires_dist:
- torch>=2.2.0
- torchvision

View File

@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.2.2"
version = "3.2.2.post2"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
@@ -45,7 +45,6 @@ dependencies = [
"pygtrie",
"dna_features_viewer",
]
# Pytest
[tool.pytest.ini_options]
addopts = """

View File

@@ -1,3 +1,2 @@
esm
esm >=3.2.1post1,<4.0.0
pytest
httpx # TODO(williamxi): Remove this after the esm repo is fixed

View File

@@ -1,6 +1,7 @@
import os
import pytest
import torch
from esm.sdk import client # pyright: ignore
from esm.sdk.api import ( # pyright: ignore
@@ -37,6 +38,8 @@ def test_oss_esm3_client():
logits_config = LogitsConfig(sequence=True, return_embeddings=True)
result = esm3_client.logits(input=encoded_protein, config=logits_config)
assert isinstance(result, LogitsOutput)
assert result.logits is not None
assert isinstance(result.logits.sequence, torch.Tensor)
sampling_config = SamplingConfig(sequence=SamplingTrackConfig(temperature=0.1))
result = esm3_client.forward_and_sample(
@@ -69,6 +72,8 @@ def test_oss_esmc_client():
)
result = esmc_client.logits(input=encoded_protein, config=logits_config)
assert isinstance(result, LogitsOutput)
assert result.logits is not None
assert isinstance(result.logits.sequence, torch.Tensor)
@pytest.mark.sdk