allow seqlen > 800

This commit is contained in:
Ishaan Mathur
2025-11-24 20:53:20 +00:00
parent 709d3e603b
commit 95c3c0281d
12 changed files with 503 additions and 23 deletions

View File

@@ -1 +1 @@
__version__ = "3.2.3"
__version__ = "3.2.4"

View File

@@ -489,6 +489,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
function=function_tokens,
residue_annotations=residue_annotation_tokens,
coordinates=coordinates,
potential_sequence_of_concern=input.potential_sequence_of_concern,
).to(next(self.parameters()).device)
def decode(self, input: ESMProteinTensor) -> ESMProtein:

View File

@@ -178,9 +178,10 @@ class ESMC(nn.Module, ESMCInferenceClient):
if input.sequence is not None:
sequence_tokens = self._tokenize([input.sequence])[0]
return ESMProteinTensor(sequence=sequence_tokens).to(
next(self.parameters()).device
)
return ESMProteinTensor(
sequence=sequence_tokens,
potential_sequence_of_concern=input.potential_sequence_of_concern,
).to(next(self.parameters()).device)
def decode(self, input: ESMProteinTensor) -> ESMProtein:
input = attr.evolve(input) # Make a copy

View File

@@ -87,6 +87,7 @@ class ESMProtein(ProteinType):
sasa=protein_chain.sasa().tolist(),
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
plddt=torch.tensor(protein_chain.confidence),
)
else:
return ESMProtein(
@@ -95,6 +96,7 @@ class ESMProtein(ProteinType):
sasa=None,
function_annotations=None,
coordinates=torch.tensor(protein_chain.atom37_positions),
plddt=torch.tensor(protein_chain.confidence),
)
@classmethod
@@ -114,6 +116,7 @@ class ESMProtein(ProteinType):
coordinates=torch.tensor(
protein_complex.atom37_positions, dtype=torch.float32
),
plddt=torch.tensor(protein_complex.confidence),
)
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
@@ -194,6 +197,9 @@ class ESMProtein(ProteinType):
chain_id=gt_chains[i].chain_id
if gt_chains is not None
else SINGLE_LETTER_CHAIN_IDS[i],
residue_index=self.residue_index[start:end]
if self.residue_index is not None
else None,
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
confidence=plddt[start:end] if plddt is not None else None,
)

View File

@@ -1,9 +1,14 @@
from typing import Any
import asyncio
import time
from abc import ABC, abstractmethod
from contextlib import suppress
from typing import Any, Generic, TypeVar
from urllib.parse import urljoin
import httpx
from esm.sdk.api import ESMProteinError
from esm.sdk.retry import retry_decorator
from esm.utils.decoding import assemble_message
@@ -72,6 +77,7 @@ class _BaseForgeInferenceClient:
request: dict[str, Any],
potential_sequence_of_concern: bool | None = None,
return_bytes: bool = False,
disable_cache: bool = False,
headers: dict[str, str] = {},
) -> tuple[dict[str, Any], dict[str, str]]:
if potential_sequence_of_concern is not None:
@@ -80,6 +86,8 @@ class _BaseForgeInferenceClient:
headers = {**self.headers, **headers}
if return_bytes:
headers["return-bytes"] = "true"
if disable_cache:
headers["X-Disable-Cache"] = "true"
return request, headers
def prepare_data(self, response, endpoint: str) -> dict[str, Any]:
@@ -108,11 +116,16 @@ class _BaseForgeInferenceClient:
potential_sequence_of_concern: bool | None = None,
params: dict[str, Any] = {},
headers: dict[str, str] = {},
disable_cache: bool = False,
return_bytes: bool = False,
):
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
request,
potential_sequence_of_concern,
return_bytes,
disable_cache,
headers,
)
response = await self.async_client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
@@ -139,10 +152,15 @@ class _BaseForgeInferenceClient:
params: dict[str, Any] = {},
headers: dict[str, str] = {},
return_bytes: bool = False,
disable_cache: bool = False,
):
try:
request, headers = self.prepare_request(
request, potential_sequence_of_concern, return_bytes, headers
request,
potential_sequence_of_concern,
return_bytes,
disable_cache,
headers,
)
response = self.client.post(
url=urljoin(self.url, f"/api/v1/{endpoint}"),
@@ -160,3 +178,264 @@ class _BaseForgeInferenceClient:
error_code=500,
error_msg=f"Failed to submit request to {endpoint}. Error: {str(e)}",
)
class _BaseForgeBatchClient(_BaseForgeInferenceClient):
"""
A Python client for the protein folding batch API.
"""
def __init__(
self,
url: str = "https://forge.evolutionaryscale.ai",
token: str = "",
request_timeout: int | None = None,
min_retry_wait: int = 1,
max_retry_wait: int = 10,
max_retry_attempts: int = 5,
poll_interval: int = 2,
):
super().__init__(
model="", # model is not used in batch client
url=url,
token=token,
request_timeout=request_timeout,
min_retry_wait=min_retry_wait,
max_retry_wait=max_retry_wait,
max_retry_attempts=max_retry_attempts,
)
# How often to poll for status
self.poll_interval = poll_interval
@retry_decorator
def submit(
self, endpoint: str, payload: list[dict[str, Any]], disable_cache: bool = False
) -> str:
response_data = self._post(
"batch/submit",
{"endpoint": endpoint, "payload": payload},
disable_cache=disable_cache,
)
task_id = response_data.get("task_id")
if not task_id:
raise ESMProteinError(
error_code=500, error_msg="API did not return a valid task_id."
)
return task_id
@retry_decorator
async def async_submit(
self, endpoint: str, payload: list[dict[str, Any]], disable_cache: bool = False
) -> str:
response_data = await self._async_post(
"batch/submit",
{"endpoint": endpoint, "payload": payload},
disable_cache=disable_cache,
)
task_id = response_data.get("task_id")
if not task_id:
raise ESMProteinError(
error_code=500, error_msg="API did not return a valid task_id."
)
return task_id
def cancel(self, task_id: str) -> dict[str, Any]:
return self._post("batch/cancel", {"task_id": task_id})
async def async_cancel(self, task_id: str) -> dict[str, Any]:
return await self._async_post("batch/cancel", {"task_id": task_id})
@retry_decorator
def get_status(self, task_id: str) -> dict[str, Any]:
return self._post("batch/status", {"task_id": task_id})
@retry_decorator
async def async_get_status(self, task_id: str) -> dict[str, Any]:
return await self._async_post("batch/status", {"task_id": task_id})
def wait_for_completion(self, task_id: str, timeout: int) -> dict:
start_time = time.time()
while time.time() - start_time < timeout:
response = self.get_status(task_id)
job_status = response.get("status")
if job_status == "done":
return response
elif job_status == "cancelled":
raise ESMProteinError(
error_code=500, error_msg=f"Job {task_id} cancelled."
)
elif job_status == "failed":
raise ESMProteinError(
error_code=500,
error_msg=f"Job {task_id} failed with error: '{response.get('error')}'.",
)
time.sleep(self.poll_interval)
raise ESMProteinError(
error_code=500,
error_msg=f"Job {task_id} timed out after {timeout} seconds.",
)
async def async_wait_for_completion(self, task_id: str, timeout: int) -> dict:
start_time = time.time()
while time.time() - start_time < timeout:
response = await self.async_get_status(task_id)
job_status = response.get("status")
if job_status == "done":
return response
elif job_status == "cancelled":
raise ESMProteinError(
error_code=500, error_msg=f"Job {task_id} cancelled."
)
elif job_status == "failed":
raise ESMProteinError(
error_code=500,
error_msg=f"Job {task_id} failed with error: '{response.get('error')}'.",
)
await asyncio.sleep(self.poll_interval)
raise ESMProteinError(
error_code=500,
error_msg=f"Job {task_id} timed out after {timeout} seconds.",
)
@retry_decorator
def get_result_from_s3(self, s3_url: str) -> dict[str, Any]:
"""Downloads the result JSON from a pre-signed S3 URL."""
try:
response = self.client.get(s3_url)
response.raise_for_status()
return response.json()
except Exception as e:
raise ESMProteinError(
error_code=500,
error_msg=f"Failed to download result from S3 URL: {s3_url}. Error: {str(e)}",
)
@retry_decorator
async def async_get_result_from_s3(self, s3_url: str) -> dict[str, Any]:
"""Asynchronously downloads the result JSON from a pre-signed S3 URL."""
try:
response = await self.async_client.get(s3_url)
response.raise_for_status()
return response.json()
except Exception as e:
raise ESMProteinError(
error_code=500,
error_msg=f"Failed to download result from S3 URL: {s3_url}. Error: {str(e)}",
)
TResponse = TypeVar("TResponse")
class EndpointHandler(ABC, Generic[TResponse]):
def __init__(self, batch_client: _BaseForgeBatchClient):
self._batch_client = batch_client
self.min_retry_wait = batch_client.min_retry_wait
self.max_retry_wait = batch_client.max_retry_wait
self.max_retry_attempts = batch_client.max_retry_attempts
@property
@abstractmethod
def endpoint_name(self) -> str:
pass
@abstractmethod
def _prepare_request(self, **kwargs) -> list[dict[str, Any]]:
pass
@abstractmethod
def _process_response(self, response: dict, **kwargs) -> TResponse:
pass
@abstractmethod
async def _async_process_response(self, response: dict, **kwargs) -> TResponse:
pass
def run(
self,
timeout: int = 300,
disable_cache: bool = False,
cancel_on_timeout: bool = True,
**kwargs,
) -> TResponse | ESMProteinError:
"""
Submit and execute a batch job, waiting for completion by polling the status of the job.
Args:
timeout: Maximum time to wait for job completion, in seconds.
disable_cache: If True, bypasses any cached results and forces
a fresh computation.
cancel_on_timeout: If True, cancels the batch job if it times out or is interrupted.
**kwargs: Arguments to pass to the batch job.
Returns:
The response from the batch job or an ESMProteinError if the job fails.
"""
task_id = None
task_timed_out = False
keyboard_interrupted = False
try:
request = self._prepare_request(**kwargs)
task_id = self._batch_client.submit(
self.endpoint_name, request, disable_cache=disable_cache
)
response = self._batch_client.wait_for_completion(task_id, timeout)
return self._process_response(response, **kwargs)
except KeyboardInterrupt:
keyboard_interrupted = True
raise
except ESMProteinError as e:
if "timed out" in e.error_msg:
task_timed_out = True
return e
finally:
if (
cancel_on_timeout
and task_id
and (task_timed_out or keyboard_interrupted)
):
with suppress(
ESMProteinError
): # Don't surface errors from canceling the task
with suppress(KeyboardInterrupt):
self._batch_client.cancel(task_id)
async def async_run(
self,
timeout: int = 300,
disable_cache: bool = False,
cancel_on_timeout: bool = True,
**kwargs,
) -> TResponse | ESMProteinError:
task_id = None
task_timed_out = False
keyboard_interrupted = False
try:
request = self._prepare_request(**kwargs)
task_id = await self._batch_client.async_submit(
self.endpoint_name, request, disable_cache=disable_cache
)
response = await self._batch_client.async_wait_for_completion(
task_id, timeout
)
return await self._async_process_response(response, **kwargs)
except KeyboardInterrupt:
keyboard_interrupted = True
raise
except ESMProteinError as e:
if "timed out" in e.error_msg:
task_timed_out = True
return e
finally:
if (
cancel_on_timeout
and task_id
and (task_timed_out or keyboard_interrupted)
):
with suppress(
ESMProteinError
): # Don't surface errors from canceling the task
with suppress(KeyboardInterrupt):
await self._batch_client.async_cancel(task_id)

View File

@@ -21,12 +21,16 @@ from esm.sdk.api import (
InverseFoldingConfig,
LogitsConfig,
LogitsOutput,
ProteinChain,
ProteinComplex,
ProteinType,
SamplingConfig,
SamplingTrackConfig,
)
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
from esm.sdk.base_forge_client import (
EndpointHandler,
_BaseForgeBatchClient,
_BaseForgeInferenceClient,
)
from esm.sdk.retry import retry_decorator
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
@@ -107,7 +111,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
sequence: str,
msa: MSA | Literal["auto"] | None,
config: FoldingConfig,
target_structure: ProteinChain | None,
target_structure: ProteinComplex | None,
model_name: str | None,
):
request: dict[str, Any] = {"sequence": sequence}
@@ -211,7 +215,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
model_name: str | None = None,
msa: MSA | Literal["auto"] | None = None,
config: FoldingConfig = FoldingConfig(),
target_structure: ProteinChain | None = None,
target_structure: ProteinComplex | None = None,
) -> ESMProtein | ESMProteinError:
"""Predict coordinates for a protein sequence.
@@ -253,7 +257,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
model_name: str | None = None,
msa: MSA | Literal["auto"] | None = None,
config: FoldingConfig = FoldingConfig(),
target_structure: ProteinChain | None = None,
target_structure: ProteinComplex | None = None,
) -> ESMProtein | ESMProteinError:
"""Predict coordinates for a protein sequence.
@@ -693,6 +697,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
sasa=maybe_tensor(data["outputs"]["sasa"]),
function=maybe_tensor(data["outputs"]["function"]),
residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
potential_sequence_of_concern=data["potential_sequence_of_concern"],
)
@staticmethod
@@ -1174,7 +1179,10 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
except ESMProteinError as e:
return e
return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"]))
return ESMProteinTensor(
sequence=maybe_tensor(data["outputs"]["sequence"]),
potential_sequence_of_concern=data["potential_sequence_of_concern"],
)
@retry_decorator
def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
@@ -1188,7 +1196,10 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
except ESMProteinError as e:
return e
return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"]))
return ESMProteinTensor(
sequence=maybe_tensor(data["outputs"]["sequence"]),
potential_sequence_of_concern=data["potential_sequence_of_concern"],
)
@retry_decorator
async def async_decode(
@@ -1277,3 +1288,146 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
raise NotImplementedError(
f"Can not get underlying remote model {self.model} from a Forge client."
)
class FoldHandler(EndpointHandler[MolecularComplexResult]):
def __init__(self, batch_client: _BaseForgeBatchClient, model: str | None = None):
super().__init__(batch_client)
self.model = model
@property
def endpoint_name(self) -> str:
return "fold"
def _prepare_request(
self,
sequence: str,
model_name: str | None = None,
config: FoldingConfig = FoldingConfig(),
msa: MSA | Literal["auto"] | None = None,
target_structure: ProteinComplex | None = None,
) -> list[dict[str, Any]]:
assert (
config.include_distogram is False
), "include_distogram is not supported for Forge right now"
seq_len = len(sequence)
if seq_len == 0:
raise ValueError(
"Input sequence length is 0. Please provide a valid input."
)
request = SequenceStructureForgeInferenceClient._process_fold_request(
sequence,
msa,
config,
target_structure,
model_name if model_name is not None else self.model,
)
# batch API expects a list of requests, and currently we only support one sequence at a time
return [request]
def _process_response(self, response: dict, **kwargs) -> ESMProtein:
s3_url = response["response"]
result = self._batch_client.get_result_from_s3(s3_url)
return SequenceStructureForgeInferenceClient._process_fold_response(
result, kwargs["sequence"]
)
async def _async_process_response(self, response: dict, **kwargs) -> ESMProtein:
s3_url = response["response"]
result = await self._batch_client.async_get_result_from_s3(s3_url)
return SequenceStructureForgeInferenceClient._process_fold_response(
result, kwargs["sequence"]
)
class FoldAllAtomHandler(EndpointHandler[MolecularComplexResult]):
def __init__(self, batch_client: _BaseForgeBatchClient, model: str | None = None):
super().__init__(batch_client)
self.model = model
@property
def endpoint_name(self) -> str:
return "fold_all_atom"
def _prepare_request(
self,
all_atom_input: StructurePredictionInput,
model_name: str | None = None,
config: FoldingConfig = FoldingConfig(),
) -> list[dict[str, Any]]:
assert (
config.include_distogram is False
), "include_distogram is not supported for Forge right now"
if len(all_atom_input.sequences) == 0:
raise ValueError(
"Input sequence length is 0. Please provide a valid input."
)
request = SequenceStructureForgeInferenceClient._process_fold_all_atom_request(
all_atom_input, config, model_name if model_name is not None else self.model
)
# batch API expects a list of requests
return [request]
def _process_response(self, response: dict, **kwargs) -> MolecularComplexResult:
s3_url = response["response"]
result = self._batch_client.get_result_from_s3(s3_url)
# Use the same logic as _process_fold_all_atom_response
return SequenceStructureForgeInferenceClient._process_fold_all_atom_response(
result
)
async def _async_process_response(
self, response: dict, **kwargs
) -> MolecularComplexResult:
s3_url = response["response"]
result = await self._batch_client.async_get_result_from_s3(s3_url)
# Use the same logic as _process_fold_all_atom_response
return SequenceStructureForgeInferenceClient._process_fold_all_atom_response(
result
)
class ForgeBatchClient:
def __init__(
self,
url: str = "https://forge.evolutionaryscale.ai",
token: str = "",
request_timeout: int | None = None,
model: str | None = None,
poll_interval: int = 2,
min_retry_wait: int = 2,
max_retry_wait: int = 2,
max_retry_attempts: int = 5,
):
self._batch_client = _BaseForgeBatchClient(
url,
token,
request_timeout,
min_retry_wait,
max_retry_wait,
max_retry_attempts,
poll_interval,
)
self.model = model
self._fold: FoldHandler | None = None
self._fold_all_atom: FoldAllAtomHandler | None = None
# Add other handlers here
@property
def fold(self) -> FoldHandler:
if self._fold is None:
self._fold = FoldHandler(self._batch_client, self.model)
return self._fold
@property
def fold_all_atom(self) -> FoldAllAtomHandler:
if self._fold_all_atom is None:
self._fold_all_atom = FoldAllAtomHandler(self._batch_client, self.model)
return self._fold_all_atom
# Add other handlers here

View File

@@ -284,6 +284,9 @@ def sample_sasa_logits(
sasa_value[max_prob_idx == 18] = float("inf")
sasa_value[~sampling_mask] = float("inf")
# Set BOS and EOS tokens to 0
sasa_value[..., 0] = 0.0
sasa_value[..., -1] = 0.0
return sasa_value

View File

@@ -118,7 +118,6 @@ def serialize_structure_prediction_input(all_atom_input: StructurePredictionInpu
result: dict[str, Any] = {"sequences": sequences}
# Add covalent bonds if present
if all_atom_input.covalent_bonds is not None:
result["covalent_bonds"] = [
{
@@ -132,4 +131,16 @@ def serialize_structure_prediction_input(all_atom_input: StructurePredictionInpu
for bond in all_atom_input.covalent_bonds
]
if all_atom_input.pocket is not None:
result["pocket"] = {
"binder_chain_id": all_atom_input.pocket.binder_chain_id,
"contacts": all_atom_input.pocket.contacts,
}
if all_atom_input.distogram_conditioning is not None:
result["distogram_conditioning"] = [
{"chain_id": disto.chain_id, "distogram": disto.distogram.tolist()}
for disto in all_atom_input.distogram_conditioning
]
return result

View File

@@ -38,6 +38,20 @@ msgpack_numpy.patch()
CHAIN_ID_CONST = "A"
def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict:
new_dict = {}
for k, v in dct.items():
v_new = v
if k not in ignore_keys and isinstance(v, dict):
v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys)
# Note assembly_composition is *supposed* to have string keys.
if isinstance(k, str) and k.isdigit():
new_dict[int(k)] = v_new
else:
new_dict[k] = v_new
return new_dict
def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
return sum(
residue.residue_number is not None
@@ -366,6 +380,9 @@ class ProteinChain:
@classmethod
def from_state_dict(cls, dct):
# Note: assembly_composition is *supposed* to have string keys.
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
for k, v in dct.items():
if isinstance(v, list):
dct[k] = np.array(v)
@@ -1121,7 +1138,9 @@ class ProteinChain:
def infer_oxygen(self) -> ProteinChain:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_missing_indices = np.argwhere(
~np.isfinite(self.atoms["O"]).all(axis=1)
).squeeze()
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)

View File

@@ -36,6 +36,7 @@ from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
from esm.utils.structure.protein_chain import (
ProteinChain,
_str_key_to_int_key,
chain_to_ndarray,
index_by_atom_name,
infer_CB,
@@ -410,6 +411,9 @@ class ProteinComplex:
@classmethod
def from_state_dict(cls, dct):
# Note: assembly_composition is *supposed* to have string keys.
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
for k, v in dct.items():
if isinstance(v, list):
dct[k] = np.array(v)
@@ -562,7 +566,9 @@ class ProteinComplex:
def infer_oxygen(self) -> ProteinComplex:
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
O_missing_indices = np.argwhere(
~np.isfinite(self.atoms["O"]).all(axis=1)
).squeeze()
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)

View File

@@ -1726,13 +1726,13 @@ packages:
requires_python: '>=3.8'
- pypi: ./
name: esm
version: 3.2.3
sha256: 7f3df1026fb23f4812615d3c4968f643f04d9cbf7735000615b011620ac83007
version: 3.2.4
sha256: fd772451bd64ae146f55638f65e42482fab780805d4fde3c81087a44c36620e3
requires_dist:
- torch>=2.2.0
- torchvision
- torchtext
- transformers<4.48.2
- transformers<4.53.0
- ipython
- einops
- biotite>=1.0.0

View File

@@ -1,10 +1,10 @@
[project]
name = "esm"
version = "3.2.3"
version = "3.2.4"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.12,<3.13"
license = {file = "LICENSE.txt"}
license = {file = "LICENSE.md"}
authors = [
{name = "EvolutionaryScale Team"}
@@ -24,7 +24,7 @@ dependencies = [
"torch>=2.2.0",
"torchvision",
"torchtext",
"transformers<4.48.2",
"transformers<4.53.0",
"ipython",
"einops",
"biotite>=1.0.0",