From 4208a695f5df1116f01fd3fb02e53f36191dc6b8 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 5 Jul 2022 16:44:48 +0000 Subject: [PATCH] Make return a single tensor --- protdiff/datasets.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/protdiff/datasets.py b/protdiff/datasets.py index 0d845a5..3710210 100644 --- a/protdiff/datasets.py +++ b/protdiff/datasets.py @@ -6,6 +6,8 @@ data loader object import os, sys import logging import json +from typing import * +import torch from tqdm.auto import tqdm @@ -58,7 +60,7 @@ class CathConsecutiveAnglesDataset(Dataset): """Returns the length of this object""" return len(self.structures) - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> torch.Tensor: if not 0 <= index < len(self): raise IndexError(index) @@ -83,7 +85,9 @@ class CathConsecutiveAnglesDataset(Dataset): logging.debug( f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}" ) - return dist_slice, omega_slice, theta_slice, phi_slice + all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice]) + assert all_values.shape == (4, n - 1) + return torch.from_numpy(all_values) def main():