mirror of
https://github.com/junliu621/PPLM.git
synced 2026-06-04 14:24:22 +08:00
471 lines
16 KiB
Python
471 lines
16 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import itertools
|
|
import os
|
|
from typing import Sequence, Tuple, List, Union
|
|
import pickle
|
|
import re
|
|
import shutil
|
|
import torch
|
|
from pathlib import Path
|
|
from pplm.constants import proteinseq_toks
|
|
|
|
# One row of a multiple sequence alignment: (label, aligned sequence string).
_AlignedRow = Tuple[str, str]
# A raw, untokenized MSA: an ordered collection of aligned rows.
RawMSA = Sequence[_AlignedRow]
|
|
|
|
class FastaBatchedDataset(object):
    """In-memory collection of (label, sequence) string pairs.

    ``from_file`` builds one from a FASTA file. ``get_batch_indices`` groups
    record indices into length-sorted batches bounded by a token budget.
    """

    def __init__(self, sequence_labels, sequence_strs):
        # Copy into lists so any iterable is accepted and indexing works.
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        """Parse a FASTA file into a dataset.

        Consecutive sequence lines are stripped and concatenated into one
        sequence. A header consisting only of ``>`` gets the synthetic label
        ``seqnum<9-digit line index>``. Sequence lines appearing before the
        first header belong to no record and are discarded; previously they
        were silently prepended to the first record's sequence.

        Args:
            fasta_file: Path to the FASTA file.

        Returns:
            A new ``FastaBatchedDataset``.

        Raises:
            AssertionError: If two records share the same label.
        """
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            # Emit the pending record (if any), then reset the parser state.
            nonlocal cur_seq_label, buf
            if cur_seq_label is not None:
                sequence_labels.append(cur_seq_label)
                sequence_strs.append("".join(buf))
            # Always clear the buffer so orphan lines seen before any header
            # cannot leak into the next record.
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        """Return the ``(label, sequence)`` pair at position *idx*."""
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Group record indices into batches under a token budget.

        Records are visited in ascending length order, with ties broken by
        index. A batch's cost is ``longest_len * num_records``, where each
        length includes ``extra_toks_per_seq``. A new batch starts whenever
        adding the next record would exceed ``toks_per_batch``. A single
        record longer than the budget still forms its own batch.

        Args:
            toks_per_batch: Maximum padded token count per batch.
            extra_toks_per_seq: Tokens added to every record's length.

        Returns:
            List of batches, each a list of indices into this dataset.
        """
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            # Close the current batch (if non-empty) and start a fresh one.
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches
|
|
|
|
|
|
class Alphabet(object):
    """Token vocabulary and tokenizer for sequences.

    The vocabulary is laid out as ``prepend_toks + standard_toks + <null_k>
    padding + append_toks``. The ``<null_1>``, ``<null_2>``, ... placeholders
    pad the tokens before ``append_toks`` to a multiple of 8. Every
    vocabulary entry is treated as an atomic, never-split unit by
    ``tokenize``.
    """

    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        # Read by BatchConverter to decide whether to add <cls>/<eos> around
        # each encoded sequence.
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        # Selects MSABatchConverter in get_batch_converter.
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad with <null_k> placeholders up to the next multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        # "<unk>" must be in the vocabulary; get_idx falls back to it.
        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")  # begin-of-sequence token
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")  # end-of-sequence token
        self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
        # NOTE: same list object as self.all_toks (an alias, not a copy).
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        """Return the index of *tok*, or ``unk_idx`` if it is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        """Return the token string at index *ind*."""
        return self.all_toks[ind]

    def to_dict(self):
        """Return a copy of the token -> index mapping."""
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: Union[int, None] = None):
        """Return an MSA or plain batch converter bound to this alphabet."""
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls) -> "Alphabet":
        """Build the default alphabet from ``proteinseq_toks["toks"]``.

        Uses <cls>/<pad>/<eos>/<unk> as prepended tokens and <mask> as the
        appended token, and enables both BOS and EOS.
        """
        standard_toks = proteinseq_toks["toks"]
        prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
        append_toks = ("<mask>",)
        prepend_bos = True
        append_eos = True
        use_msa = False

        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)

    def _tokenize(self, text) -> List[str]:
        # Fallback for text containing no vocabulary token: whitespace split.
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer.

        Every entry of ``unique_no_split_tokens`` is split out as an atomic
        token. Any remaining text is split on whitespace.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """
        no_split_token = self.unique_no_split_tokens
        # Set copy for O(1) membership tests; the list keeps the split order.
        no_split = set(no_split_token)

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior depending on the special token
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                # We strip left and right by default
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in no_split:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            # Iterate text_list rather than the loop-local list, so an empty
            # tok_list still tokenizes the input instead of returning [].
            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in no_split
                        else [token]
                        for token in text_list
                    )
                )
            )

        return split_on_tokens(no_split_token, text)

    def encode(self, text):
        """Tokenize *text* and map each token to its vocabulary index.

        Raises:
            KeyError: For a token not in the vocabulary. Unlike ``get_idx``,
                this does not fall back to ``<unk>``.
        """
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
|
|
|
|
|
|
class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: Union[int, None] = None):
        # alphabet must provide encode(), prepend_bos, append_eos,
        # padding_idx, cls_idx and eos_idx.
        self.alphabet = alphabet
        # Maximum number of encoded tokens kept per sequence, applied before
        # BOS/EOS are added. None (or any falsy value) disables truncation.
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        """Encode and pad a batch of ``(label, sequence)`` pairs.

        Each row of the returned tensor holds ``[cls] tokens [eos] pad...``.
        ``cls`` is present only when ``alphabet.prepend_bos`` is set, and
        ``eos`` only when ``alphabet.append_eos`` is set. (RoBERTa uses an eos
        token, while ESM-1 does not.)

        Args:
            raw_batch: Sequence of ``(label, sequence_string)`` pairs.

        Returns:
            ``(labels, strs, tokens)``:
              - ``labels`` and ``strs`` are lists of the inputs in order; the
                strings are not truncated.
              - ``tokens`` is an int64 tensor of shape
                ``(batch, max_encoded_len + bos + eos)``.
            An empty batch yields ``([], [], tensor of shape (0, bos + eos))``
            instead of raising.
        """
        n_bos = int(self.alphabet.prepend_bos)
        n_eos = int(self.alphabet.append_eos)
        batch_size = len(raw_batch)
        if batch_size == 0:
            # zip(*[]) and max([]) would otherwise raise ValueError.
            empty = torch.full(
                (0, n_bos + n_eos), self.alphabet.padding_idx, dtype=torch.int64
            )
            return [], [], empty

        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [
                seq_encoded[: self.truncation_seq_length]
                for seq_encoded in seq_encoded_list
            ]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.full(
            (batch_size, max_len + n_bos + n_eos),
            self.alphabet.padding_idx,
            dtype=torch.int64,
        )
        labels = []
        strs = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            if n_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[i, n_bos : len(seq_encoded) + n_bos] = seq
            if n_eos:
                tokens[i, len(seq_encoded) + n_bos] = self.alphabet.eos_idx

        return labels, strs, tokens
|
|
|
|
|
|
|
|
class MSABatchConverter(BatchConverter):
    """Batch converter for multiple sequence alignments (MSAs).

    Accepts either a single MSA or a batch of MSAs. Each MSA is a sequence
    of ``(label, aligned_sequence)`` pairs. Every MSA is converted with the
    plain ``BatchConverter`` and the results are stacked into a padded 3-D
    tensor.
    """

    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        # For a single MSA the first element is a (label, seq) pair, so its
        # first item is a str; for a batch it is itself a pair.
        single_msa = isinstance(inputs[0][0], str)
        raw_batch: Sequence[RawMSA] = [inputs] if single_msa else inputs  # type: ignore

        depth = max(len(alignment) for alignment in raw_batch)
        width = max(len(alignment[0][1]) for alignment in raw_batch)
        extra = int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos)

        tokens = torch.empty((len(raw_batch), depth, width + extra), dtype=torch.int64)
        tokens.fill_(self.alphabet.padding_idx)

        labels, strs = [], []
        for idx, alignment in enumerate(raw_batch):
            row_lengths = {len(row_seq) for _, row_seq in alignment}
            if len(row_lengths) != 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(alignment)
            labels.append(msa_labels)
            strs.append(msa_strs)
            n_rows, n_cols = msa_tokens.size(0), msa_tokens.size(1)
            tokens[idx, :n_rows, :n_cols] = msa_tokens

        return labels, strs, tokens
|
|
|
|
|
|
def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Lazily yield ``(description, sequence)`` records from the file at *path*.

    Opens the file in text mode and delegates all parsing, including the
    ``keep_gaps``, ``keep_insertions`` and ``to_upper`` options, to
    ``read_alignment_lines``. The file stays open until the generator is
    exhausted or closed.
    """
    with open(path, "r") as handle:
        yield from read_alignment_lines(
            handle,
            keep_gaps=keep_gaps,
            keep_insertions=keep_insertions,
            to_upper=to_upper,
        )
|
|
|
|
|
|
def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Parse FASTA-style lines into ``(description, sequence)`` records.

    Args:
        lines: Iterable of text lines (e.g. an open text file).
        keep_gaps: If False, remove ``-`` characters from each sequence.
        keep_insertions: If False, remove lowercase ``a``-``z`` characters
            from each sequence.
        to_upper: If True, upper-case each sequence after the filters above.

    Yields:
        ``(desc, seq)`` pairs:
          - ``desc`` is the header line with surrounding whitespace and
            leading ``>`` characters removed.
          - ``seq`` is the concatenation of the stripped lines that follow
            the header.
        Blank lines are ignored, and empty input yields nothing.

    Raises:
        ValueError: If a non-blank sequence line appears before any header.
    """
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = s.replace("-", "")
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        stripped = line.strip()
        # Blank lines contribute no residues. Skipping them also lets files
        # begin with blank lines (e.g. when seq % file_line_width == 0 leaves
        # an empty line).
        if not stripped:
            continue
        if line[0] == ">":
            if seq is not None:
                yield desc, parse(seq)
            desc = stripped.lstrip(">")
            seq = ""
        else:
            if seq is None:
                raise ValueError(
                    f"Sequence line found before any '>' header: {stripped[:30]!r}"
                )
            seq += stripped

    # seq is still None only if no header was ever seen (empty input).
    if seq is not None:
        yield desc, parse(seq)
|
|
|
|
|
|
class ESMStructuralSplitDataset(torch.utils.data.Dataset):
    """
    Structural Split Dataset as described in section A.10 of the supplement of our paper.
    https://doi.org/10.1101/622803

    We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
    generated on January 23, 2020.

    For each SCOPe domain:
        - We extract the sequence from the corresponding PDB file
        - We extract the 3D coordinates of the Carbon beta atoms, aligning them
          to the sequence. We put NaN where Cb atoms are missing.
        - From the 3D coordinates, we calculate a pairwise distance map, based
          on L2 distance
        - We use DSSP to generate secondary structure labels for the corresponding
          PDB file. This is also aligned to the sequence. We put - where SSP
          labels are missing.

    For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
    we have split the data into 5 partitions for cross validation. These are provided
    in a downloaded splits folder, in the format:
            splits/{split_level}/{cv_partition}/{train|valid}.txt
    where train is the partition and valid is the concatenation of the remaining 4.

    For each SCOPe domain, we provide a pkl dump that contains:
        - seq    : The domain sequence, stored as an L-length string
        - ssp    : The secondary structure labels, stored as an L-length string
        - dist   : The distance map, stored as an LxL numpy array
        - coords : The 3D coordinates, stored as an Lx3 numpy array

    """

    base_folder = "structural-data"
    file_list = [
        #  url  tar filename   filename      MD5 Hash
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
            "splits.tar.gz",
            "splits",
            "456fe1c7f22c9d3d8dfe9735da52411d",
        ),
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
            "pkl.tar.gz",
            "pkl",
            "644ea91e56066c750cd50101d390f5db",
        ),
    ]

    def __init__(
        self,
        split_level,
        cv_partition,
        split,
        root_path=os.path.expanduser("~/.cache/torch/data/pplm"),
        download=False,
    ):
        """Load the list of domain names for one split.

        Args:
            split_level: Directory name under ``splits/`` (the class docstring
                lists family/superfamily/fold).
            cv_partition: Cross-validation partition directory name.
            split: ``"train"`` or ``"valid"``.
            root_path: Directory that contains (or will contain)
                ``structural-data/``.
            download: If True, fetch and unpack the archives first.

        Raises:
            ValueError: If ``split`` is not ``"train"`` or ``"valid"``.
        """
        super().__init__()
        if split not in ("train", "valid"):
            raise ValueError(f"split must be 'train' or 'valid', got {split!r}")
        self.root_path = root_path
        self.base_path = os.path.join(self.root_path, self.base_folder)

        # check if root path has what you need or else download it
        if download:
            self.download()

        self.split_file = os.path.join(
            self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
        )
        self.pkl_dir = os.path.join(self.base_path, "pkl")
        # One domain name per line of the split file.
        with open(self.split_file) as f:
            self.names = f.read().splitlines()

    def __len__(self):
        return len(self.names)

    def _check_exists(self) -> bool:
        """Return True if every unpacked data directory in file_list is present."""
        for (_, _, filename, _) in self.file_list:
            fpath = os.path.join(self.base_path, filename)
            # isdir() is False for missing paths, so no separate exists() check.
            if not os.path.isdir(fpath):
                return False
        return True

    def download(self):
        """Download (with MD5 check) and unpack the archives unless already present."""

        if self._check_exists():
            print("Files already downloaded and verified")
            return

        # Imported lazily so torchvision is only needed when downloading.
        from torchvision.datasets.utils import download_url

        for url, tar_filename, filename, md5_hash in self.file_list:
            download_path = os.path.join(self.base_path, tar_filename)
            download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
            shutil.unpack_archive(download_path, self.base_path)

    def __getitem__(self, idx):
        """
        Returns a dict with the following entries
            - seq : Str (domain sequence)
            - ssp : Str (SSP labels)
            - dist : np.array (distance map)
            - coords : np.array (3D coordinates)
        """
        name = self.names[idx]
        # Files are sharded into subdirectories by characters 1-2 of the name.
        pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # root_path at data from a trusted source (e.g. the file_list URLs).
        with open(pkl_fname, "rb") as f:
            obj = pickle.load(f)
        return obj
|