mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Make return a single tensor
This commit is contained in:
@@ -6,6 +6,8 @@ data loader object
|
||||
import os, sys
|
||||
import logging
|
||||
import json
|
||||
from typing import *
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -58,7 +60,7 @@ class CathConsecutiveAnglesDataset(Dataset):
|
||||
"""Returns the length of this object"""
|
||||
return len(self.structures)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
def __getitem__(self, index: int) -> torch.Tensor:
|
||||
if not 0 <= index < len(self):
|
||||
raise IndexError(index)
|
||||
|
||||
@@ -83,7 +85,9 @@ class CathConsecutiveAnglesDataset(Dataset):
|
||||
logging.debug(
|
||||
f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
|
||||
)
|
||||
return dist_slice, omega_slice, theta_slice, phi_slice
|
||||
all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice])
|
||||
assert all_values.shape == (4, n - 1)
|
||||
return torch.from_numpy(all_values)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user