Files
openfold/scripts/generate_esm_embeddings.py

27 lines
1.2 KiB
Python

# import torch
#
# # Load ESM-1b model
# model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
# batch_converter = alphabet.get_batch_converter()
#
# # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
# data = [
# ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
# ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
# ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
# ("protein3", "K A <mask> I S Q"),
# ]
# batch_labels, batch_strs, batch_tokens = batch_converter(data)
#
# # Extract per-residue representations (on CPU)
# with torch.no_grad():
#     results = model(batch_tokens, repr_layers=[33], return_contacts=True)
# token_representations = results["representations"][33]
#
# # Generate per-sequence representations via averaging
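# # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
# # NOTE(review): len(seq) counts characters, not tokens. For the entries above that contain
# # "<mask>" or spaces, the slice below probably extends past the real residues into
# # end-of-sequence/padding positions — confirm, and consider using per-sequence token counts.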
# sequence_representations = []
# for i, (_, seq) in enumerate(data):
#     sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))
# print(sequence_representations)