Files
openfold/scripts/precompute_esm_embeddings.py

155 lines
4.8 KiB
Python

import argparse
import pathlib
import torch
from openfold.data import parsers
class SequenceDataset:
    """In-memory collection of labelled sequences read from a FASTA file.

    Indexing yields ``(label, sequence)`` pairs, which makes instances usable
    as a map-style dataset for ``torch.utils.data.DataLoader``.
    """

    def __init__(self, labels, sequences):
        """Store parallel lists of labels and sequences."""
        self.labels = labels
        self.sequences = sequences

    @classmethod
    def from_file(cls, fasta_file):
        """Build a dataset from the FASTA file at ``fasta_file``.

        Raises:
            ValueError: if two records share the same label. Labels are used
                to name the output files, so duplicates would overwrite each
                other.
        """
        with open(fasta_file, "r") as infile:
            fasta_str = infile.read()
        # parse_fasta returns the sequences first and their labels second.
        sequences, labels = parsers.parse_fasta(fasta_str)
        # Raise explicitly instead of asserting so the check survives ``python -O``.
        if len(set(labels)) != len(labels):
            raise ValueError(
                "Sequence labels need to be unique. Duplicates found!"
            )
        return cls(labels, sequences)

    def __len__(self):
        """Return the number of sequences."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(label, sequence)`` pair at position ``idx``."""
        return self.labels[idx], self.sequences[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq):
        """Group sequence indices into batches bounded by a padded token budget.

        Sequences are visited from shortest to longest so that sequences of
        similar length share a batch. A batch is closed as soon as adding the
        next sequence would make ``longest_length * batch_size`` exceed
        ``toks_per_batch``.

        Args:
            toks_per_batch: maximum number of (padded) tokens per batch.
            extra_toks_per_seq: tokens added to every sequence length before
                budgeting (e.g. special tokens inserted by the tokenizer).

        Returns:
            A list of batches, each a list of indices into this dataset. A
            sequence that alone exceeds the budget still gets its own batch.
        """
        # Sort by (length, index); the index breaks ties deterministically.
        by_length = sorted((len(seq), idx) for idx, seq in enumerate(self.sequences))

        batches = []
        current = []
        longest = 0
        for length, idx in by_length:
            length += extra_toks_per_seq
            # Cost of the batch if this sequence were added: every member is
            # padded to the longest one.
            if max(length, longest) * (len(current) + 1) > toks_per_batch:
                if current:
                    batches.append(current)
                current = []
                longest = 0
            longest = max(longest, length)
            current.append(idx)

        if current:
            batches.append(current)
        return batches
def main(args):
    """Compute per-residue ESM-1b embeddings for the sequences in a FASTA file.

    For every sequence a file ``<output dir>/<label>/<label>.pt`` is written
    with ``torch.save``. It holds a dict with the sequence ``label`` and, when
    per-token output is requested, a ``representations`` dict mapping each
    requested layer index to that sequence's embedding tensor.

    Args:
        args: parsed command-line namespace. Reads ``fasta_path``,
            ``toks_per_batch``, ``repr_layers``, ``truncate``, ``nogpu``, the
            output directory (``output_dir``, falling back to
            ``output_path``) and, if present, ``include``.

    Raises:
        ValueError: if a requested layer index lies outside
            ``[-(num_layers + 1), num_layers]``.
    """
    # NOTE(review): the checkpoint name is hard-coded here; ``args.model`` is
    # not consulted. Confirm whether that is intended.
    model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
    model.eval()

    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        print("Using device=GPU")
        model = model.cuda()

    # Accept the output directory under either attribute name.
    output_dir = getattr(args, "output_dir", None)
    if output_dir is None:
        output_dir = args.output_path
    # Fall back to per-token output when the namespace carries no ``include``.
    include = getattr(args, "include", None) or ["per_tok"]

    fasta_file_path = args.fasta_path
    dataset = SequenceDataset.from_file(fasta_file_path)
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
    )
    print(f"Processed {fasta_file_path}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Validate the requested layers, then map negative indices (counted from
    # the end, like Python list indices) onto 0..num_layers.
    num_layers = model.num_layers
    invalid = [i for i in args.repr_layers if not -(num_layers + 1) <= i <= num_layers]
    if invalid:
        raise ValueError(
            f"repr_layers out of range [{-(num_layers + 1)}, {num_layers}]: {invalid}"
        )
    repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in args.repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            # ESM cannot handle >1022 residue sequences, so clip along the
            # token axis (dim 1); dim 0 indexes the sequences in the batch.
            if args.truncate:
                toks = toks[:, :1022]

            out = model(toks, repr_layers=repr_layers)
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                output_file = output_dir / label / f"{label}.pt"
                output_file.parent.mkdir(parents=True, exist_ok=True)

                result = {"label": label}
                if "per_tok" in include:
                    # Skip position 0 (presumably the tokenizer's start token;
                    # confirm against the ESM alphabet) and keep one row per
                    # residue. clone() detaches the slice from the batch
                    # tensor so only this sequence's data is serialized.
                    result["representations"] = {
                        layer: t[i, 1 : len(strs[i]) + 1].clone()
                        for layer, t in representations.items()
                    }
                torch.save(result, output_file)
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Extract per-residue ESM embeddings for sequences in a FASTA file, "
        "or a dir of FASTA files."
    )
    # NOTE: ``model`` is a required positional, so argparse never applies the
    # ``default`` given below; it is kept only as documentation of the usual value.
    parser.add_argument("model",
        type=str,
        help="Either the name of a pretrained model to download via PyTorch Hub, "
        "or path to a PyTorch model file.",
        default="esm1b_t33_650M_UR50S"
    )
    parser.add_argument("fasta_path",
        type=pathlib.Path,
        help="Path to a FASTA file, or to a dir containing FASTA files."
    )
    parser.add_argument("output_path",
        type=pathlib.Path,
        help="Output directory where ESM representations will be stored."
    )
    parser.add_argument(
        "--toks_per_batch",
        type=int,
        default=4096,
        help="Maximum batch size, i.e. number of tokens in batch."
    )
    parser.add_argument(
        "--repr_layers",
        type=int,
        default=[-1],
        nargs="+",
        help="Layers indices from which to extract representations (0 to num_layers, inclusive)."
    )
    parser.add_argument(
        "--include",
        type=str,
        nargs="+",
        choices=["per_tok"],
        default=["per_tok"],
        help="Which outputs to store for each sequence. Default: per_tok"
    )
    parser.add_argument(
        "--truncate",
        action="store_true",
        default=True,
        help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
    )
    # ``--truncate`` defaults to True and can only set True, so a separate
    # switch is needed to turn truncation off.
    parser.add_argument(
        "--no_truncate",
        dest="truncate",
        action="store_false",
        help="Do not truncate sequences longer than 1022."
    )
    parser.add_argument("--nogpu",
        action="store_true",
        help="Do not use GPU"
    )

    args = parser.parse_args()
    # Expose the destination under ``output_dir`` as well, which is the
    # attribute name main() looks up.
    args.output_dir = args.output_path
    main(args)