Add logits example (#77)

This commit is contained in:
Jun Gong
2024-08-06 11:23:38 -07:00
committed by GitHub
parent 542a71d2d1
commit aaabedcf58
2 changed files with 24 additions and 1 deletions

View File

@@ -514,9 +514,16 @@ class ESM3(nn.Module, ESM3InferenceClient):
)
def logits(
self, input: ESMProteinTensor, config: LogitsConfig = LogitsConfig()
self,
input: ESMProteinTensor | _BatchedESMProteinTensor,
config: LogitsConfig = LogitsConfig(),
) -> LogitsOutput:
if not isinstance(input, _BatchedESMProteinTensor):
# Create batch dimension if necessary.
input = _BatchedESMProteinTensor.from_protein_tensor(input)
device = torch.device(input.device)
# Default plddt conditioning for inference. 1s where coordinates are provided.
if input.coordinates is None:
per_res_plddt = None

View File

@@ -5,6 +5,8 @@ from esm.sdk.api import (
ESMProteinError,
ESMProteinTensor,
GenerationConfig,
LogitsConfig,
LogitsOutput,
SamplingConfig,
SamplingTrackConfig,
)
@@ -87,6 +89,20 @@ def main(client: ESM3InferenceClient):
protein_with_function, ESMProtein
), f"{protein_with_function} is not an ESMProtein"
# Logits
protein = get_sample_protein()
protein.coordinates = None
protein.function_annotations = None
protein.sasa = None
protein_tensor = client.encode(protein)
logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True))
assert isinstance(
logits_output, LogitsOutput
), f"LogitsOutput was expected but got {logits_output}"
assert (
logits_output.logits is not None and logits_output.logits.sequence is not None
)
# Chain of Thought (Function -> Secondary Structure -> Structure -> Sequence)
cot_protein = get_sample_protein()
cot_protein.sequence = "_" * len(cot_protein.sequence) # type: ignore