Make return a single tensor

This commit is contained in:
Kevin Wu
2022-07-05 16:44:48 +00:00
parent 57603e6959
commit 4208a695f5

View File

@@ -6,6 +6,8 @@ data loader object
import os, sys
import logging
import json
from typing import *
import torch
from tqdm.auto import tqdm
@@ -58,7 +60,7 @@ class CathConsecutiveAnglesDataset(Dataset):
"""Returns the length of this object"""
return len(self.structures)
def __getitem__(self, index: int):
def __getitem__(self, index: int) -> torch.Tensor:
if not 0 <= index < len(self):
raise IndexError(index)
@@ -83,7 +85,9 @@ class CathConsecutiveAnglesDataset(Dataset):
logging.debug(
f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
)
return dist_slice, omega_slice, theta_slice, phi_slice
all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice])
assert all_values.shape == (4, n - 1)
return torch.from_numpy(all_values)
def main():