mirror of
https://github.com/evolutionaryscale/esm.git
synced 2026-06-04 17:14:23 +08:00
Add logits example (#77)
This commit is contained in:
@@ -514,9 +514,16 @@ class ESM3(nn.Module, ESM3InferenceClient):
|
||||
)
|
||||
|
||||
def logits(
|
||||
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
|
||||
self,
|
||||
input: ESMProteinTensor | _BatchedESMProteinTensor,
|
||||
config: LogitsConfig = LogitsConfig(),
|
||||
) -> LogitsOutput:
|
||||
if not isinstance(input, _BatchedESMProteinTensor):
|
||||
# Create batch dimension if necessary.
|
||||
input = _BatchedESMProteinTensor.from_protein_tensor(input)
|
||||
|
||||
device = torch.device(input.device)
|
||||
|
||||
# Default plddt conditioning for inference. 1s where coordinates are provided.
|
||||
if input.coordinates is None:
|
||||
per_res_plddt = None
|
||||
|
||||
@@ -5,6 +5,8 @@ from esm.sdk.api import (
|
||||
ESMProteinError,
|
||||
ESMProteinTensor,
|
||||
GenerationConfig,
|
||||
LogitsConfig,
|
||||
LogitsOutput,
|
||||
SamplingConfig,
|
||||
SamplingTrackConfig,
|
||||
)
|
||||
@@ -87,6 +89,20 @@ def main(client: ESM3InferenceClient):
|
||||
protein_with_function, ESMProtein
|
||||
), f"{protein_with_function} is not an ESMProtein"
|
||||
|
||||
# Logits
|
||||
protein = get_sample_protein()
|
||||
protein.coordinates = None
|
||||
protein.function_annotations = None
|
||||
protein.sasa = None
|
||||
protein_tensor = client.encode(protein)
|
||||
logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True))
|
||||
assert isinstance(
|
||||
logits_output, LogitsOutput
|
||||
), f"LogitsOutput was expected but got {logits_output}"
|
||||
assert (
|
||||
logits_output.logits is not None and logits_output.logits.sequence is not None
|
||||
)
|
||||
|
||||
# Chain of Thought (Function -> Secondary Structure -> Structure -> Sequence)
|
||||
cot_protein = get_sample_protein()
|
||||
cot_protein.sequence = "_" * len(cot_protein.sequence) # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user