mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
27 lines
1.2 KiB
Python
27 lines
1.2 KiB
Python
# import torch
|
|
#
|
|
# # Load ESM-1b model
|
|
# model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
|
|
# batch_converter = alphabet.get_batch_converter()
|
|
#
|
|
# # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4, plus a masked copy of protein2 and a short space-separated example)
|
|
# data = [
|
|
# ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
|
|
# ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
|
|
# ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
|
|
# ("protein3", "K A <mask> I S Q"),
|
|
# ]
|
|
# batch_labels, batch_strs, batch_tokens = batch_converter(data)
|
|
#
|
|
# # Extract per-residue representations (on CPU)
|
|
# with torch.no_grad():
|
|
# results = model(batch_tokens, repr_layers=[33], return_contacts=True)
|
|
# token_representations = results["representations"][33]
|
|
#
|
|
# # Generate per-sequence representations via averaging
|
|
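# # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
# # NOTE(review): len(seq) counts raw characters, so for entries containing "<mask>" or spaces it presumably exceeds the token count and the slice below may include non-residue positions — TODO confirm.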
|
|
# sequence_representations = []
|
|
# for i, (_, seq) in enumerate(data):
|
|
# sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))
|
|
# print(sequence_representations)
|