Files
foldingdiff/protdiff/datasets.py
2022-07-25 18:02:04 +00:00

318 lines
12 KiB
Python

"""
Contains source code for loading in data and creating the requisite PyTorch
dataset objects (torch.utils.data.Dataset subclasses)
"""
import functools
import multiprocessing
import os, sys
import logging
import json
from typing import *
import torch
from tqdm.auto import tqdm
import numpy as np
from torch.utils.data import Dataset
# Directory holding the CATH data files: "data/cath" under the parent of the
# directory that contains this module.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
CATH_DIR = os.path.join(os.path.dirname(_MODULE_DIR), "data/cath")
# Fail at import time if the data directory is missing.
assert os.path.isdir(CATH_DIR), f"Expected cath data at {CATH_DIR}"
from sequence_models import pdb_utils
import beta_schedules
import utils
class CathConsecutiveAnglesDataset(Dataset):
    """
    Represent proteins as their constituent angles instead of 3D coordinates

    The three angles phi, psi, and omega determine the backbone structure.
    Omega is typically fixed ~180 degrees in most cases.

    By default the angles are given in range of [-pi, pi] but we want to shift
    these to [0, 2pi] so we can have easier modulo math. This behavior can be
    toggled in shift_to_zero_twopi

    Structures are read from chain_set.jsonl under CATH_DIR (optionally
    restricted to one split listed in chain_set_splits.json), converted with
    coords_to_angles, and dropped when that conversion returns None. Every
    item returned by __getitem__ is zero-padded or truncated to exactly
    ``pad`` rows.

    Useful reading:
    - https://proteinstructures.com/structure/ramachandran-plot/
    - https://foldit.fandom.com/wiki/Backbone_angle
    - http://www.cryst.bbk.ac.uk/PPS95/course/9_quaternary/3_geometry/torsion.html
    - https://swissmodel.expasy.org/course/text/chapter1.htm
    - https://www.nature.com/articles/s41598-020-76317-6
    - https://userguide.mdanalysis.org/1.1.1/examples/analysis/structure/dihedrals.html
    """

    def __init__(
        self,
        split: Optional[Literal["train", "test", "validation"]] = None,
        pad: int = 512,
        shift_to_zero_twopi: bool = True,
        toy: bool = False,
    ) -> None:
        """
        Load structures, optionally subset them to a split, and compute angles

        split: name of a key in chain_set_splits.json; None keeps everything
        pad: number of rows every returned example is padded/truncated to
        shift_to_zero_twopi: forwarded to coords_to_angles as
            shift_angles_positive
        toy: if True, only the first 152 lines of the data file are read
        """
        super().__init__()
        self.pad = pad
        self.shift_to_zero_twopi = shift_to_zero_twopi

        # json list file -- each line is a json
        data_file = os.path.join(CATH_DIR, "chain_set.jsonl")
        assert os.path.isfile(data_file)
        self.structures = []
        with open(data_file) as source:
            for i, line in enumerate(source):
                structure = json.loads(line.strip())
                self.structures.append(structure)
                if toy and i > 150:
                    break

        # Get data split if given
        self.split = split
        if split is not None:
            splitfile = os.path.join(CATH_DIR, "chain_set_splits.json")
            with open(splitfile) as source:
                data_splits = json.load(source)
            assert split in data_splits.keys(), f"Unrecognized split: {split}"
            # Subset self.structures to only the given names. The set is built
            # once up front so each membership test is O(1) instead of
            # re-creating the set for every structure.
            orig_len = len(self.structures)
            split_names = set(data_splits[self.split])
            self.structures = [s for s in self.structures if s["name"] in split_names]
            logging.info(
                f"Split {self.split} contains {len(self.structures)}/{orig_len} examples"
            )

        # Generate angles in parallel and attach them to corresponding
        # structures. The context manager guarantees the worker processes are
        # cleaned up even if the conversion raises; map() has already collected
        # every result by the time the block exits.
        pfunc = functools.partial(
            coords_to_angles, shift_angles_positive=self.shift_to_zero_twopi
        )
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            angles = pool.map(
                pfunc, [d["coords"] for d in self.structures], chunksize=250
            )
        for s, a in zip(self.structures, angles):
            s["angles"] = a

        # Remove items with nan in angles/structures
        orig_count = len(self.structures)
        self.structures = [s for s in self.structures if s["angles"] is not None]
        new_count = len(self.structures)
        logging.info(f"Removed structures with nan {orig_count} -> {new_count}")

        # Aggregate the lengths
        self.all_lengths = [s["angles"].shape[0] for s in self.structures]
        self._length_rng = np.random.default_rng(seed=6489)
        if self.all_lengths:
            logging.info(
                f"Length of angles: {np.min(self.all_lengths)}-{np.max(self.all_lengths)}, mean {np.mean(self.all_lengths)}"
            )
        else:
            # np.min/np.max raise on empty input (e.g. toy=True combined with a
            # split that shares no names with the first few structures)
            logging.warning("No structures remain after filtering; dataset is empty")

    def sample_length(self) -> int:
        """
        Sample an observed length of a sequence

        Draws one entry of self.all_lengths using the dataset's own seeded
        random generator.
        """
        return self._length_rng.choice(self.all_lengths)

    def __len__(self) -> int:
        """Returns the number of structures in this dataset"""
        return len(self.structures)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """
        Return the index-th example, padded or truncated to self.pad rows

        The returned dictionary contains:
        - "angles": float tensor with self.pad rows; rows past the original
          length are zero
        - "attn_mask": tensor of length self.pad, 1.0 on real rows and 0.0 on
          padding
        - "position_ids": long tensor 0..self.pad-1

        Raises IndexError if index is outside [0, len(self)); negative indices
        are not supported.
        """
        if not 0 <= index < len(self):
            raise IndexError(index)

        angles = self.structures[index]["angles"]
        assert angles is not None

        # Number of rows that hold real (non-padding) values
        num_valid = min(self.pad, angles.shape[0])
        attn_mask = torch.zeros(size=(self.pad,))
        attn_mask[:num_valid] = 1.0

        # Pad or trim to given length
        if angles.shape[0] < self.pad:
            orig_shape = angles.shape
            angles = np.pad(
                angles,
                ((0, self.pad - angles.shape[0]), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            logging.debug(f"Padded {orig_shape} -> {angles.shape}")
        elif angles.shape[0] > self.pad:
            angles = angles[: self.pad]

        position_ids = torch.arange(start=0, end=self.pad, step=1, dtype=torch.long)
        retval = torch.from_numpy(angles).float()
        return {"angles": retval, "attn_mask": attn_mask, "position_ids": position_ids}
class NoisedAnglesDataset(Dataset):
    """
    class that produces noised outputs given a wrapped dataset.

    Wrapped dset should return a tensor from __getitem__ if dset_key
    is not specified; otherwise, returns a dictionary where the item
    to noise is under dset_key

    modulo can be given as either a float or a list of floats. A single value
    is applied to every element of the noised tensor; a list is applied
    per column (one entry per column, an entry of 0 leaves that column as is).
    """

    def __init__(
        self,
        dset: Dataset,
        dset_key: Optional[str] = None,
        timesteps: int = 1000,
        beta_schedule: beta_schedules.SCHEDULES = "linear",
        modulo: Optional[Union[float, Iterable[float]]] = None,
    ) -> None:
        """
        Wrap dset and precompute the noise schedule terms

        dset: dataset whose items (or items[dset_key]) are tensors to noise
        dset_key: key of the tensor to noise when dset returns dictionaries
        timesteps: number of timesteps in the variance schedule
        beta_schedule: schedule name passed to
            beta_schedules.get_variance_schedule
        modulo: optional modulus (scalar or per-column) applied after noising
        """
        super().__init__()
        self.dset = dset
        self.dset_key = dset_key
        self.modulo = modulo

        self.timesteps = timesteps
        self.schedule = beta_schedule
        self.betas = beta_schedules.get_variance_schedule(beta_schedule, timesteps)
        self.alphas = 1.0 - self.betas
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        # dim= is the documented torch keyword for the reduction axis
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def __len__(self) -> int:
        """Returns the length of the wrapped dataset"""
        return len(self.dset)

    def __getitem__(
        self, index, use_t_val: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Gets the i-th item in the dataset and adds noise
        use_t_val is useful for manually querying specific timepoints; it is
        clipped to [0, timesteps - 1]. When it is None a timestep is drawn
        uniformly at random.

        Returns a dictionary with keys "corrupted" (noised values), "t" (the
        timestep, a long tensor of shape (1,)) and "known_noise" (the sampled
        noise). If the wrapped dataset returns a dictionary, its entries are
        included as well; the wrapped dataset's own dictionary is not modified.
        """
        item = self.dset.__getitem__(index)
        # If wrapped dset returns a dictionary then we extract the item to noise
        if self.dset_key is not None:
            assert isinstance(item, dict)
            vals = item[self.dset_key]
        else:
            vals = item
        assert isinstance(vals, torch.Tensor), f"Expected tensor but got {type(vals)}"

        # Sample a random timepoint and add corresponding noise
        if use_t_val is not None:
            t_val = np.clip(np.array([use_t_val]), 0, self.timesteps - 1)
            t = torch.from_numpy(t_val).long()
        else:
            t = torch.randint(0, self.timesteps, (1,)).long()

        # NOTE(review): utils.extract presumably picks the schedule entry at t
        # in a shape broadcastable against vals -- confirm against utils
        sqrt_alphas_cumprod_t = utils.extract(self.sqrt_alphas_cumprod, t, vals.shape)
        sqrt_one_minus_alphas_cumprod_t = utils.extract(
            self.sqrt_one_minus_alphas_cumprod, t, vals.shape
        )
        noise = torch.randn_like(vals)
        noised_vals = (
            sqrt_alphas_cumprod_t * vals + sqrt_one_minus_alphas_cumprod_t * noise
        )

        # If modulo is given ensure that we do modulo
        if self.modulo is not None:
            # Only the iterability probe lives inside the try, so a TypeError
            # raised while applying the per-column moduli is not mistaken for
            # "modulo is a scalar".
            try:
                iter(self.modulo)
            except TypeError:
                # not iterable --> float
                noised_vals = noised_vals % self.modulo
            else:
                assert (
                    len(self.modulo) == noised_vals.shape[1]
                ), f"Mismatched shapes: {noised_vals.shape} vs {len(self.modulo)} modulo terms"
                # Apply each modulus to its column; a modulus of 0 means
                # "leave this column untouched".
                # Should have shape (seq_len, 4)
                noised_vals = torch.vstack(
                    [
                        col % m if m != 0 else col
                        for col, m in zip(noised_vals.T, self.modulo)
                    ]
                ).T

        retval = {
            "corrupted": noised_vals,
            "t": t,
            "known_noise": noise,
        }

        # Update dictionary if wrapped dset returns dicts, else just return
        if isinstance(item, dict):
            assert item.keys().isdisjoint(retval.keys())
            # Merge into a copy so a wrapped dataset that hands out the same
            # dict object on every access is not polluted with noise keys
            merged = dict(item)
            merged.update(retval)
            return merged
        return retval
def coords_to_angles(
    coords: Dict[str, List[List[float]]], shift_angles_positive: bool = True
) -> Optional[np.ndarray]:
    """
    Sanitize the coordinates to not have NaN and convert them into
    arrays of angles. If sanitization fails, return None

    Leading and trailing rows containing NaN in any of the "N", "CA", "C"
    entries are trimmed off. None is returned when an entry is empty, when an
    entry has no NaN-free row, when the NaN-free windows of the entries do not
    overlap, or when NaN remains inside the trimmed window. The input
    dictionary is not modified.

    if shift_angles_positive, take the angles given in [-pi, pi] range and
    shift them to [0, 2pi] range

    Returns an array of shape (n - 1, 4) whose columns are
    (dist, omega, theta, phi) between consecutive positions, where n is the
    size of the matrices returned by pdb_utils.process_coords.
    """
    backbone_atoms = ("N", "CA", "C")
    first_valid_idx, last_valid_idx = 0, len(coords["N"])

    # Walk through coordinates and find the window free of leading/trailing nan
    for k in backbone_atoms:
        logging.debug(f"{k}:\t{coords[k][:5]}")
        arr = np.array(coords[k])
        if arr.size == 0:
            logging.debug(f"No coordinates given for {k}")
            return None
        # Get all valid indices
        valid_idx = np.where(~np.any(np.isnan(arr), axis=1))[0]
        if valid_idx.size == 0:
            # Every row has a nan; np.min/np.max would raise on empty input
            logging.debug(f"No valid coordinates for {k}")
            return None
        first_valid_idx = max(first_valid_idx, np.min(valid_idx))
        last_valid_idx = min(last_valid_idx, np.max(valid_idx) + 1)
    logging.debug(f"Trimming nans keeps {first_valid_idx}:{last_valid_idx}")
    if first_valid_idx >= last_valid_idx:
        # The valid windows of the individual atoms do not overlap
        return None

    # Shallow copy so the caller's dictionary keeps its untrimmed lists; any
    # additional keys are passed through unchanged
    trimmed = dict(coords)
    for k in backbone_atoms:
        trimmed[k] = coords[k][first_valid_idx:last_valid_idx]
        if np.any(np.isnan(np.array(trimmed[k]))):
            logging.debug("Got nan in middle of array")
            return None

    angles = pdb_utils.process_coords(trimmed)
    # https://www.rosettacommons.org/docs/latest/application_documentation/trRosetta/trRosetta#application-purpose_a-note-on-nomenclature
    # omega = inter-residue dihedral angle between CA/CB of first and CB/CA of second
    # theta = inter-residue dihedral angle between N, CA, CB of first and CB of second
    # phi = inter-residue angle between CA and CB of first and CB of second
    dist, omega, theta, phi = angles
    assert dist.shape == omega.shape == theta.shape == phi.shape
    logging.debug(f"Pre slice shape: {dist.shape, omega.shape, theta.shape, phi.shape}")

    # Slice out so that we have the angles and distances between the n and n+1 items
    n = dist.shape[0]
    indices_i = np.arange(n - 1)
    indices_j = indices_i + 1
    dist_slice = dist[indices_i, indices_j]
    omega_slice = omega[indices_i, indices_j]
    theta_slice = theta[indices_i, indices_j]
    phi_slice = phi[indices_i, indices_j]
    logging.debug(
        f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
    )

    all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice]).T
    assert all_values.shape == (n - 1, 4)
    # Column 0 is a distance; only the three angle columns are range-checked
    assert np.all(
        np.logical_and(all_values[:, 1:] <= np.pi, all_values[:, 1:] >= -np.pi)
    ), "Angle values outside of [-pi, pi] range"

    if shift_angles_positive:
        all_values[:, 1:] += np.pi

    return all_values
def main():
    """Build the noised "train" split dataset and print its first example"""
    train_dset = CathConsecutiveAnglesDataset(toy=False, split="train")
    noised_dset = NoisedAnglesDataset(train_dset, dset_key="angles")
    first_example = noised_dset[0]
    # Print each key followed by its value
    for key in first_example:
        print(key)
        print(first_example[key])
if __name__ == "__main__":
    # Script entry point: show INFO-level log messages, then run the demo
    logging.basicConfig(level=logging.INFO)
    main()