mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
allow seqlen > 800
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "3.2.3"
|
||||
__version__ = "3.2.4"
|
||||
|
||||
@@ -489,6 +489,7 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
function=function_tokens,
|
||||
residue_annotations=residue_annotation_tokens,
|
||||
coordinates=coordinates,
|
||||
potential_sequence_of_concern=input.potential_sequence_of_concern,
|
||||
).to(next(self.parameters()).device)
|
||||
|
||||
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
||||
|
||||
@@ -178,9 +178,10 @@ class ESMC(nn.Module, ESMCInferenceClient):
|
||||
|
||||
if input.sequence is not None:
|
||||
sequence_tokens = self._tokenize([input.sequence])[0]
|
||||
return ESMProteinTensor(sequence=sequence_tokens).to(
|
||||
next(self.parameters()).device
|
||||
)
|
||||
return ESMProteinTensor(
|
||||
sequence=sequence_tokens,
|
||||
potential_sequence_of_concern=input.potential_sequence_of_concern,
|
||||
).to(next(self.parameters()).device)
|
||||
|
||||
def decode(self, input: ESMProteinTensor) -> ESMProtein:
|
||||
input = attr.evolve(input) # Make a copy
|
||||
|
||||
@@ -87,6 +87,7 @@ class ESMProtein(ProteinType):
|
||||
sasa=protein_chain.sasa().tolist(),
|
||||
function_annotations=None,
|
||||
coordinates=torch.tensor(protein_chain.atom37_positions),
|
||||
plddt=torch.tensor(protein_chain.confidence),
|
||||
)
|
||||
else:
|
||||
return ESMProtein(
|
||||
@@ -95,6 +96,7 @@ class ESMProtein(ProteinType):
|
||||
sasa=None,
|
||||
function_annotations=None,
|
||||
coordinates=torch.tensor(protein_chain.atom37_positions),
|
||||
plddt=torch.tensor(protein_chain.confidence),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -114,6 +116,7 @@ class ESMProtein(ProteinType):
|
||||
coordinates=torch.tensor(
|
||||
protein_complex.atom37_positions, dtype=torch.float32
|
||||
),
|
||||
plddt=torch.tensor(protein_complex.confidence),
|
||||
)
|
||||
|
||||
def to_pdb(self, pdb_path: PathOrBuffer) -> None:
|
||||
@@ -194,6 +197,9 @@ class ESMProtein(ProteinType):
|
||||
chain_id=gt_chains[i].chain_id
|
||||
if gt_chains is not None
|
||||
else SINGLE_LETTER_CHAIN_IDS[i],
|
||||
residue_index=self.residue_index[start:end]
|
||||
if self.residue_index is not None
|
||||
else None,
|
||||
entity_id=gt_chains[i].entity_id if gt_chains is not None else None,
|
||||
confidence=plddt[start:end] if plddt is not None else None,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from typing import Any
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import suppress
|
||||
from typing import Any, Generic, TypeVar
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from esm.sdk.api import ESMProteinError
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.decoding import assemble_message
|
||||
|
||||
|
||||
@@ -72,6 +77,7 @@ class _BaseForgeInferenceClient:
|
||||
request: dict[str, Any],
|
||||
potential_sequence_of_concern: bool | None = None,
|
||||
return_bytes: bool = False,
|
||||
disable_cache: bool = False,
|
||||
headers: dict[str, str] = {},
|
||||
) -> tuple[dict[str, Any], dict[str, str]]:
|
||||
if potential_sequence_of_concern is not None:
|
||||
@@ -80,6 +86,8 @@ class _BaseForgeInferenceClient:
|
||||
headers = {**self.headers, **headers}
|
||||
if return_bytes:
|
||||
headers["return-bytes"] = "true"
|
||||
if disable_cache:
|
||||
headers["X-Disable-Cache"] = "true"
|
||||
return request, headers
|
||||
|
||||
def prepare_data(self, response, endpoint: str) -> dict[str, Any]:
|
||||
@@ -108,11 +116,16 @@ class _BaseForgeInferenceClient:
|
||||
potential_sequence_of_concern: bool | None = None,
|
||||
params: dict[str, Any] = {},
|
||||
headers: dict[str, str] = {},
|
||||
disable_cache: bool = False,
|
||||
return_bytes: bool = False,
|
||||
):
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
return_bytes,
|
||||
disable_cache,
|
||||
headers,
|
||||
)
|
||||
response = await self.async_client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
@@ -139,10 +152,15 @@ class _BaseForgeInferenceClient:
|
||||
params: dict[str, Any] = {},
|
||||
headers: dict[str, str] = {},
|
||||
return_bytes: bool = False,
|
||||
disable_cache: bool = False,
|
||||
):
|
||||
try:
|
||||
request, headers = self.prepare_request(
|
||||
request, potential_sequence_of_concern, return_bytes, headers
|
||||
request,
|
||||
potential_sequence_of_concern,
|
||||
return_bytes,
|
||||
disable_cache,
|
||||
headers,
|
||||
)
|
||||
response = self.client.post(
|
||||
url=urljoin(self.url, f"/api/v1/{endpoint}"),
|
||||
@@ -160,3 +178,264 @@ class _BaseForgeInferenceClient:
|
||||
error_code=500,
|
||||
error_msg=f"Failed to submit request to {endpoint}. Error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
class _BaseForgeBatchClient(_BaseForgeInferenceClient):
|
||||
"""
|
||||
A Python client for the protein folding batch API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "https://forge.evolutionaryscale.ai",
|
||||
token: str = "",
|
||||
request_timeout: int | None = None,
|
||||
min_retry_wait: int = 1,
|
||||
max_retry_wait: int = 10,
|
||||
max_retry_attempts: int = 5,
|
||||
poll_interval: int = 2,
|
||||
):
|
||||
super().__init__(
|
||||
model="", # model is not used in batch client
|
||||
url=url,
|
||||
token=token,
|
||||
request_timeout=request_timeout,
|
||||
min_retry_wait=min_retry_wait,
|
||||
max_retry_wait=max_retry_wait,
|
||||
max_retry_attempts=max_retry_attempts,
|
||||
)
|
||||
# How often to poll for status
|
||||
self.poll_interval = poll_interval
|
||||
|
||||
@retry_decorator
|
||||
def submit(
|
||||
self, endpoint: str, payload: list[dict[str, Any]], disable_cache: bool = False
|
||||
) -> str:
|
||||
response_data = self._post(
|
||||
"batch/submit",
|
||||
{"endpoint": endpoint, "payload": payload},
|
||||
disable_cache=disable_cache,
|
||||
)
|
||||
task_id = response_data.get("task_id")
|
||||
if not task_id:
|
||||
raise ESMProteinError(
|
||||
error_code=500, error_msg="API did not return a valid task_id."
|
||||
)
|
||||
return task_id
|
||||
|
||||
@retry_decorator
|
||||
async def async_submit(
|
||||
self, endpoint: str, payload: list[dict[str, Any]], disable_cache: bool = False
|
||||
) -> str:
|
||||
response_data = await self._async_post(
|
||||
"batch/submit",
|
||||
{"endpoint": endpoint, "payload": payload},
|
||||
disable_cache=disable_cache,
|
||||
)
|
||||
task_id = response_data.get("task_id")
|
||||
if not task_id:
|
||||
raise ESMProteinError(
|
||||
error_code=500, error_msg="API did not return a valid task_id."
|
||||
)
|
||||
return task_id
|
||||
|
||||
def cancel(self, task_id: str) -> dict[str, Any]:
|
||||
return self._post("batch/cancel", {"task_id": task_id})
|
||||
|
||||
async def async_cancel(self, task_id: str) -> dict[str, Any]:
|
||||
return await self._async_post("batch/cancel", {"task_id": task_id})
|
||||
|
||||
@retry_decorator
|
||||
def get_status(self, task_id: str) -> dict[str, Any]:
|
||||
return self._post("batch/status", {"task_id": task_id})
|
||||
|
||||
@retry_decorator
|
||||
async def async_get_status(self, task_id: str) -> dict[str, Any]:
|
||||
return await self._async_post("batch/status", {"task_id": task_id})
|
||||
|
||||
def wait_for_completion(self, task_id: str, timeout: int) -> dict:
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
response = self.get_status(task_id)
|
||||
job_status = response.get("status")
|
||||
if job_status == "done":
|
||||
return response
|
||||
elif job_status == "cancelled":
|
||||
raise ESMProteinError(
|
||||
error_code=500, error_msg=f"Job {task_id} cancelled."
|
||||
)
|
||||
elif job_status == "failed":
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Job {task_id} failed with error: '{response.get('error')}'.",
|
||||
)
|
||||
time.sleep(self.poll_interval)
|
||||
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Job {task_id} timed out after {timeout} seconds.",
|
||||
)
|
||||
|
||||
async def async_wait_for_completion(self, task_id: str, timeout: int) -> dict:
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
response = await self.async_get_status(task_id)
|
||||
job_status = response.get("status")
|
||||
if job_status == "done":
|
||||
return response
|
||||
elif job_status == "cancelled":
|
||||
raise ESMProteinError(
|
||||
error_code=500, error_msg=f"Job {task_id} cancelled."
|
||||
)
|
||||
elif job_status == "failed":
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Job {task_id} failed with error: '{response.get('error')}'.",
|
||||
)
|
||||
await asyncio.sleep(self.poll_interval)
|
||||
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Job {task_id} timed out after {timeout} seconds.",
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def get_result_from_s3(self, s3_url: str) -> dict[str, Any]:
|
||||
"""Downloads the result JSON from a pre-signed S3 URL."""
|
||||
try:
|
||||
response = self.client.get(s3_url)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Failed to download result from S3 URL: {s3_url}. Error: {str(e)}",
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
async def async_get_result_from_s3(self, s3_url: str) -> dict[str, Any]:
|
||||
"""Asynchronously downloads the result JSON from a pre-signed S3 URL."""
|
||||
try:
|
||||
response = await self.async_client.get(s3_url)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise ESMProteinError(
|
||||
error_code=500,
|
||||
error_msg=f"Failed to download result from S3 URL: {s3_url}. Error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
TResponse = TypeVar("TResponse")
|
||||
|
||||
|
||||
class EndpointHandler(ABC, Generic[TResponse]):
|
||||
def __init__(self, batch_client: _BaseForgeBatchClient):
|
||||
self._batch_client = batch_client
|
||||
self.min_retry_wait = batch_client.min_retry_wait
|
||||
self.max_retry_wait = batch_client.max_retry_wait
|
||||
self.max_retry_attempts = batch_client.max_retry_attempts
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def endpoint_name(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _prepare_request(self, **kwargs) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _process_response(self, response: dict, **kwargs) -> TResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _async_process_response(self, response: dict, **kwargs) -> TResponse:
|
||||
pass
|
||||
|
||||
def run(
|
||||
self,
|
||||
timeout: int = 300,
|
||||
disable_cache: bool = False,
|
||||
cancel_on_timeout: bool = True,
|
||||
**kwargs,
|
||||
) -> TResponse | ESMProteinError:
|
||||
"""
|
||||
Submit and execute a batch job, waiting for completion by polling the status of the job.
|
||||
Args:
|
||||
timeout: Maximum time to wait for job completion, in seconds.
|
||||
disable_cache: If True, bypasses any cached results and forces
|
||||
a fresh computation.
|
||||
cancel_on_timeout: If True, cancels the batch job if it times out or is interrupted.
|
||||
**kwargs: Arguments to pass to the batch job.
|
||||
Returns:
|
||||
The response from the batch job or an ESMProteinError if the job fails.
|
||||
"""
|
||||
task_id = None
|
||||
task_timed_out = False
|
||||
keyboard_interrupted = False
|
||||
try:
|
||||
request = self._prepare_request(**kwargs)
|
||||
task_id = self._batch_client.submit(
|
||||
self.endpoint_name, request, disable_cache=disable_cache
|
||||
)
|
||||
response = self._batch_client.wait_for_completion(task_id, timeout)
|
||||
return self._process_response(response, **kwargs)
|
||||
except KeyboardInterrupt:
|
||||
keyboard_interrupted = True
|
||||
raise
|
||||
except ESMProteinError as e:
|
||||
if "timed out" in e.error_msg:
|
||||
task_timed_out = True
|
||||
return e
|
||||
finally:
|
||||
if (
|
||||
cancel_on_timeout
|
||||
and task_id
|
||||
and (task_timed_out or keyboard_interrupted)
|
||||
):
|
||||
with suppress(
|
||||
ESMProteinError
|
||||
): # Don't surface errors from canceling the task
|
||||
with suppress(KeyboardInterrupt):
|
||||
self._batch_client.cancel(task_id)
|
||||
|
||||
async def async_run(
|
||||
self,
|
||||
timeout: int = 300,
|
||||
disable_cache: bool = False,
|
||||
cancel_on_timeout: bool = True,
|
||||
**kwargs,
|
||||
) -> TResponse | ESMProteinError:
|
||||
task_id = None
|
||||
task_timed_out = False
|
||||
keyboard_interrupted = False
|
||||
try:
|
||||
request = self._prepare_request(**kwargs)
|
||||
task_id = await self._batch_client.async_submit(
|
||||
self.endpoint_name, request, disable_cache=disable_cache
|
||||
)
|
||||
response = await self._batch_client.async_wait_for_completion(
|
||||
task_id, timeout
|
||||
)
|
||||
return await self._async_process_response(response, **kwargs)
|
||||
except KeyboardInterrupt:
|
||||
keyboard_interrupted = True
|
||||
raise
|
||||
except ESMProteinError as e:
|
||||
if "timed out" in e.error_msg:
|
||||
task_timed_out = True
|
||||
return e
|
||||
finally:
|
||||
if (
|
||||
cancel_on_timeout
|
||||
and task_id
|
||||
and (task_timed_out or keyboard_interrupted)
|
||||
):
|
||||
with suppress(
|
||||
ESMProteinError
|
||||
): # Don't surface errors from canceling the task
|
||||
with suppress(KeyboardInterrupt):
|
||||
await self._batch_client.async_cancel(task_id)
|
||||
|
||||
168
esm/sdk/forge.py
168
esm/sdk/forge.py
@@ -21,12 +21,16 @@ from esm.sdk.api import (
|
||||
InverseFoldingConfig,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
ProteinChain,
|
||||
ProteinComplex,
|
||||
ProteinType,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
from esm.sdk.base_forge_client import _BaseForgeInferenceClient
|
||||
from esm.sdk.base_forge_client import (
|
||||
EndpointHandler,
|
||||
_BaseForgeBatchClient,
|
||||
_BaseForgeInferenceClient,
|
||||
)
|
||||
from esm.sdk.retry import retry_decorator
|
||||
from esm.utils.constants.api import MIMETYPE_ES_PICKLE
|
||||
from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor
|
||||
@@ -107,7 +111,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
sequence: str,
|
||||
msa: MSA | Literal["auto"] | None,
|
||||
config: FoldingConfig,
|
||||
target_structure: ProteinChain | None,
|
||||
target_structure: ProteinComplex | None,
|
||||
model_name: str | None,
|
||||
):
|
||||
request: dict[str, Any] = {"sequence": sequence}
|
||||
@@ -211,7 +215,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
model_name: str | None = None,
|
||||
msa: MSA | Literal["auto"] | None = None,
|
||||
config: FoldingConfig = FoldingConfig(),
|
||||
target_structure: ProteinChain | None = None,
|
||||
target_structure: ProteinComplex | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
"""Predict coordinates for a protein sequence.
|
||||
|
||||
@@ -253,7 +257,7 @@ class SequenceStructureForgeInferenceClient(_BaseForgeInferenceClient):
|
||||
model_name: str | None = None,
|
||||
msa: MSA | Literal["auto"] | None = None,
|
||||
config: FoldingConfig = FoldingConfig(),
|
||||
target_structure: ProteinChain | None = None,
|
||||
target_structure: ProteinComplex | None = None,
|
||||
) -> ESMProtein | ESMProteinError:
|
||||
"""Predict coordinates for a protein sequence.
|
||||
|
||||
@@ -693,6 +697,7 @@ class ESM3ForgeInferenceClient(ESM3InferenceClient, _BaseForgeInferenceClient):
|
||||
sasa=maybe_tensor(data["outputs"]["sasa"]),
|
||||
function=maybe_tensor(data["outputs"]["function"]),
|
||||
residue_annotations=maybe_tensor(data["outputs"]["residue_annotation"]),
|
||||
potential_sequence_of_concern=data["potential_sequence_of_concern"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1174,7 +1179,10 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"]))
|
||||
return ESMProteinTensor(
|
||||
sequence=maybe_tensor(data["outputs"]["sequence"]),
|
||||
potential_sequence_of_concern=data["potential_sequence_of_concern"],
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
def encode(self, input: ESMProtein) -> ESMProteinTensor | ESMProteinError:
|
||||
@@ -1188,7 +1196,10 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
except ESMProteinError as e:
|
||||
return e
|
||||
|
||||
return ESMProteinTensor(sequence=maybe_tensor(data["outputs"]["sequence"]))
|
||||
return ESMProteinTensor(
|
||||
sequence=maybe_tensor(data["outputs"]["sequence"]),
|
||||
potential_sequence_of_concern=data["potential_sequence_of_concern"],
|
||||
)
|
||||
|
||||
@retry_decorator
|
||||
async def async_decode(
|
||||
@@ -1277,3 +1288,146 @@ class ESMCForgeInferenceClient(ESMCInferenceClient, _BaseForgeInferenceClient):
|
||||
raise NotImplementedError(
|
||||
f"Can not get underlying remote model {self.model} from a Forge client."
|
||||
)
|
||||
|
||||
|
||||
class FoldHandler(EndpointHandler[MolecularComplexResult]):
|
||||
def __init__(self, batch_client: _BaseForgeBatchClient, model: str | None = None):
|
||||
super().__init__(batch_client)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def endpoint_name(self) -> str:
|
||||
return "fold"
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
sequence: str,
|
||||
model_name: str | None = None,
|
||||
config: FoldingConfig = FoldingConfig(),
|
||||
msa: MSA | Literal["auto"] | None = None,
|
||||
target_structure: ProteinComplex | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
assert (
|
||||
config.include_distogram is False
|
||||
), "include_distogram is not supported for Forge right now"
|
||||
|
||||
seq_len = len(sequence)
|
||||
if seq_len == 0:
|
||||
raise ValueError(
|
||||
"Input sequence length is 0. Please provide a valid input."
|
||||
)
|
||||
|
||||
request = SequenceStructureForgeInferenceClient._process_fold_request(
|
||||
sequence,
|
||||
msa,
|
||||
config,
|
||||
target_structure,
|
||||
model_name if model_name is not None else self.model,
|
||||
)
|
||||
# batch API expects a list of requests, and currently we only support one sequence at a time
|
||||
return [request]
|
||||
|
||||
def _process_response(self, response: dict, **kwargs) -> ESMProtein:
|
||||
s3_url = response["response"]
|
||||
result = self._batch_client.get_result_from_s3(s3_url)
|
||||
return SequenceStructureForgeInferenceClient._process_fold_response(
|
||||
result, kwargs["sequence"]
|
||||
)
|
||||
|
||||
async def _async_process_response(self, response: dict, **kwargs) -> ESMProtein:
|
||||
s3_url = response["response"]
|
||||
result = await self._batch_client.async_get_result_from_s3(s3_url)
|
||||
return SequenceStructureForgeInferenceClient._process_fold_response(
|
||||
result, kwargs["sequence"]
|
||||
)
|
||||
|
||||
|
||||
class FoldAllAtomHandler(EndpointHandler[MolecularComplexResult]):
|
||||
def __init__(self, batch_client: _BaseForgeBatchClient, model: str | None = None):
|
||||
super().__init__(batch_client)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def endpoint_name(self) -> str:
|
||||
return "fold_all_atom"
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
all_atom_input: StructurePredictionInput,
|
||||
model_name: str | None = None,
|
||||
config: FoldingConfig = FoldingConfig(),
|
||||
) -> list[dict[str, Any]]:
|
||||
assert (
|
||||
config.include_distogram is False
|
||||
), "include_distogram is not supported for Forge right now"
|
||||
|
||||
if len(all_atom_input.sequences) == 0:
|
||||
raise ValueError(
|
||||
"Input sequence length is 0. Please provide a valid input."
|
||||
)
|
||||
|
||||
request = SequenceStructureForgeInferenceClient._process_fold_all_atom_request(
|
||||
all_atom_input, config, model_name if model_name is not None else self.model
|
||||
)
|
||||
# batch API expects a list of requests
|
||||
return [request]
|
||||
|
||||
def _process_response(self, response: dict, **kwargs) -> MolecularComplexResult:
|
||||
s3_url = response["response"]
|
||||
result = self._batch_client.get_result_from_s3(s3_url)
|
||||
# Use the same logic as _process_fold_all_atom_response
|
||||
return SequenceStructureForgeInferenceClient._process_fold_all_atom_response(
|
||||
result
|
||||
)
|
||||
|
||||
async def _async_process_response(
|
||||
self, response: dict, **kwargs
|
||||
) -> MolecularComplexResult:
|
||||
s3_url = response["response"]
|
||||
result = await self._batch_client.async_get_result_from_s3(s3_url)
|
||||
# Use the same logic as _process_fold_all_atom_response
|
||||
return SequenceStructureForgeInferenceClient._process_fold_all_atom_response(
|
||||
result
|
||||
)
|
||||
|
||||
|
||||
class ForgeBatchClient:
|
||||
def __init__(
|
||||
self,
|
||||
url: str = "https://forge.evolutionaryscale.ai",
|
||||
token: str = "",
|
||||
request_timeout: int | None = None,
|
||||
model: str | None = None,
|
||||
poll_interval: int = 2,
|
||||
min_retry_wait: int = 2,
|
||||
max_retry_wait: int = 2,
|
||||
max_retry_attempts: int = 5,
|
||||
):
|
||||
self._batch_client = _BaseForgeBatchClient(
|
||||
url,
|
||||
token,
|
||||
request_timeout,
|
||||
min_retry_wait,
|
||||
max_retry_wait,
|
||||
max_retry_attempts,
|
||||
poll_interval,
|
||||
)
|
||||
self.model = model
|
||||
|
||||
self._fold: FoldHandler | None = None
|
||||
self._fold_all_atom: FoldAllAtomHandler | None = None
|
||||
# Add other handlers here
|
||||
|
||||
@property
|
||||
def fold(self) -> FoldHandler:
|
||||
if self._fold is None:
|
||||
self._fold = FoldHandler(self._batch_client, self.model)
|
||||
return self._fold
|
||||
|
||||
@property
|
||||
def fold_all_atom(self) -> FoldAllAtomHandler:
|
||||
if self._fold_all_atom is None:
|
||||
self._fold_all_atom = FoldAllAtomHandler(self._batch_client, self.model)
|
||||
return self._fold_all_atom
|
||||
|
||||
# Add other handlers here
|
||||
|
||||
@@ -284,6 +284,9 @@ def sample_sasa_logits(
|
||||
sasa_value[max_prob_idx == 18] = float("inf")
|
||||
sasa_value[~sampling_mask] = float("inf")
|
||||
|
||||
# Set BOS and EOS tokens to 0
|
||||
sasa_value[..., 0] = 0.0
|
||||
sasa_value[..., -1] = 0.0
|
||||
return sasa_value
|
||||
|
||||
|
||||
|
||||
@@ -118,7 +118,6 @@ def serialize_structure_prediction_input(all_atom_input: StructurePredictionInpu
|
||||
|
||||
result: dict[str, Any] = {"sequences": sequences}
|
||||
|
||||
# Add covalent bonds if present
|
||||
if all_atom_input.covalent_bonds is not None:
|
||||
result["covalent_bonds"] = [
|
||||
{
|
||||
@@ -132,4 +131,16 @@ def serialize_structure_prediction_input(all_atom_input: StructurePredictionInpu
|
||||
for bond in all_atom_input.covalent_bonds
|
||||
]
|
||||
|
||||
if all_atom_input.pocket is not None:
|
||||
result["pocket"] = {
|
||||
"binder_chain_id": all_atom_input.pocket.binder_chain_id,
|
||||
"contacts": all_atom_input.pocket.contacts,
|
||||
}
|
||||
|
||||
if all_atom_input.distogram_conditioning is not None:
|
||||
result["distogram_conditioning"] = [
|
||||
{"chain_id": disto.chain_id, "distogram": disto.distogram.tolist()}
|
||||
for disto in all_atom_input.distogram_conditioning
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
@@ -38,6 +38,20 @@ msgpack_numpy.patch()
|
||||
CHAIN_ID_CONST = "A"
|
||||
|
||||
|
||||
def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict:
|
||||
new_dict = {}
|
||||
for k, v in dct.items():
|
||||
v_new = v
|
||||
if k not in ignore_keys and isinstance(v, dict):
|
||||
v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys)
|
||||
# Note assembly_composition is *supposed* to have string keys.
|
||||
if isinstance(k, str) and k.isdigit():
|
||||
new_dict[int(k)] = v_new
|
||||
else:
|
||||
new_dict[k] = v_new
|
||||
return new_dict
|
||||
|
||||
|
||||
def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
|
||||
return sum(
|
||||
residue.residue_number is not None
|
||||
@@ -366,6 +380,9 @@ class ProteinChain:
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, dct):
|
||||
# Note: assembly_composition is *supposed* to have string keys.
|
||||
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
|
||||
|
||||
for k, v in dct.items():
|
||||
if isinstance(v, list):
|
||||
dct[k] = np.array(v)
|
||||
@@ -1121,7 +1138,9 @@ class ProteinChain:
|
||||
|
||||
def infer_oxygen(self) -> ProteinChain:
|
||||
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
|
||||
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
|
||||
O_missing_indices = np.argwhere(
|
||||
~np.isfinite(self.atoms["O"]).all(axis=1)
|
||||
).squeeze()
|
||||
|
||||
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
|
||||
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
|
||||
|
||||
@@ -36,6 +36,7 @@ from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca
|
||||
from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError
|
||||
from esm.utils.structure.protein_chain import (
|
||||
ProteinChain,
|
||||
_str_key_to_int_key,
|
||||
chain_to_ndarray,
|
||||
index_by_atom_name,
|
||||
infer_CB,
|
||||
@@ -410,6 +411,9 @@ class ProteinComplex:
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, dct):
|
||||
# Note: assembly_composition is *supposed* to have string keys.
|
||||
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
|
||||
|
||||
for k, v in dct.items():
|
||||
if isinstance(v, list):
|
||||
dct[k] = np.array(v)
|
||||
@@ -562,7 +566,9 @@ class ProteinComplex:
|
||||
|
||||
def infer_oxygen(self) -> ProteinComplex:
|
||||
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
|
||||
O_missing_indices = np.argwhere(np.isnan(self.atoms["O"]).any(axis=1)).squeeze()
|
||||
O_missing_indices = np.argwhere(
|
||||
~np.isfinite(self.atoms["O"]).all(axis=1)
|
||||
).squeeze()
|
||||
|
||||
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
|
||||
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
|
||||
|
||||
@@ -1726,13 +1726,13 @@ packages:
|
||||
requires_python: '>=3.8'
|
||||
- pypi: ./
|
||||
name: esm
|
||||
version: 3.2.3
|
||||
sha256: 7f3df1026fb23f4812615d3c4968f643f04d9cbf7735000615b011620ac83007
|
||||
version: 3.2.4
|
||||
sha256: fd772451bd64ae146f55638f65e42482fab780805d4fde3c81087a44c36620e3
|
||||
requires_dist:
|
||||
- torch>=2.2.0
|
||||
- torchvision
|
||||
- torchtext
|
||||
- transformers<4.48.2
|
||||
- transformers<4.53.0
|
||||
- ipython
|
||||
- einops
|
||||
- biotite>=1.0.0
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
[project]
|
||||
name = "esm"
|
||||
version = "3.2.3"
|
||||
version = "3.2.4"
|
||||
description = "EvolutionaryScale open model repository"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
license = {file = "LICENSE.txt"}
|
||||
license = {file = "LICENSE.md"}
|
||||
|
||||
authors = [
|
||||
{name = "EvolutionaryScale Team"}
|
||||
@@ -24,7 +24,7 @@ dependencies = [
|
||||
"torch>=2.2.0",
|
||||
"torchvision",
|
||||
"torchtext",
|
||||
"transformers<4.48.2",
|
||||
"transformers<4.53.0",
|
||||
"ipython",
|
||||
"einops",
|
||||
"biotite>=1.0.0",
|
||||
|
||||
Reference in New Issue
Block a user