mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 18:04:23 +08:00
102 lines
3.6 KiB
Python
102 lines
3.6 KiB
Python
from collections.abc import Mapping, Sequence
|
|
from typing import List, Optional, Union
|
|
|
|
import torch.utils.data
|
|
from torch.utils.data.dataloader import default_collate
|
|
|
|
from torch_geometric.data import Batch, Dataset
|
|
from torch_geometric.data.data import BaseData
|
|
|
|
|
|
class Collater:
    """Callable that merges a list of samples into a single mini-batch.

    ``None`` entries are discarded first; the type of the first remaining
    sample then decides how the whole list is merged.
    """

    def __init__(self, follow_batch, exclude_keys):
        # Both values are handed unchanged to ``Batch.from_data_list``.
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        """Collate ``batch`` based on the type of its first non-``None`` item.

        Nested containers (mappings, named tuples, sequences) are collated
        recursively, field by field.

        Raises:
            IndexError: If nothing is left once ``None`` entries are dropped.
            TypeError: If the first sample has a type that is not handled.
        """
        batch = [sample for sample in batch if sample is not None]
        elem = batch[0]

        if isinstance(elem, BaseData):
            return Batch.from_data_list(
                batch,
                self.follow_batch,
                self.exclude_keys,
            )

        if isinstance(elem, torch.Tensor):
            return default_collate(batch)

        # ``float`` is tested before ``int`` so each keeps its own dtype rule.
        if isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)

        if isinstance(elem, int):
            return torch.tensor(batch)

        if isinstance(elem, str):
            return batch

        if isinstance(elem, Mapping):
            merged = {}
            for key in elem:
                merged[key] = self([sample[key] for sample in batch])
            return merged

        # Named tuples are rebuilt with their own type, one field at a time.
        if isinstance(elem, tuple) and hasattr(elem, '_fields'):
            fields = [self(column) for column in zip(*batch)]
            return type(elem)(*fields)

        # Strings already returned above, so any sequence reaching this point
        # is a non-string one.
        if isinstance(elem, Sequence):
            return [self(column) for column in zip(*batch)]

        raise TypeError(f'DataLoader found invalid type: {type(elem)}')

    def collate(self, batch):  # Deprecated...
        """Deprecated alias that simply calls the instance on ``batch``."""
        return self(batch)
|
|
|
|
|
|
class DataLoader(torch.utils.data.DataLoader):
    r"""Data loader that collates samples drawn from a
    :class:`torch_geometric.data.Dataset` into mini-batches.

    Samples may be :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` objects.

    Args:
        dataset (Dataset): Source of the samples.
        batch_size (int, optional): Number of samples per mini-batch.
            (default: :obj:`1`)
        shuffle (bool, optional): Whether to reshuffle the data at every
            epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Keys for which assignment batch
            vectors are created. (default: :obj:`None`)
        exclude_keys (List[str], optional): Keys to leave out of the
            batch. (default: :obj:`None`)
        **kwargs (optional): Extra arguments forwarded to
            :class:`torch.utils.data.DataLoader`. A ``collate_fn`` entry is
            discarded because this class always installs its own collater.
    """
    def __init__(
        self,
        dataset: Union[Dataset, List[BaseData]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # A caller-supplied collate function is dropped, not honoured.
        kwargs.pop('collate_fn', None)

        collater = Collater(follow_batch, exclude_keys)

        # Save for PyTorch Lightning:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=collater,
            **kwargs,
        )
|
|
|
|
|
|
def collate_fn(data_list):
    """Return the samples of ``data_list`` as a new list, skipping ``None``.

    Only entries that are exactly ``None`` are removed; other falsy values
    such as ``0`` or ``''`` are kept, and the original order is preserved.
    """
    return [sample for sample in data_list if sample is not None]
|
|
|
|
|
|
class DataListLoader(torch.utils.data.DataLoader):
    """Data loader whose batches are plain Python lists of samples.

    Batches are built by the module-level ``collate_fn``, which removes
    ``None`` entries and returns the remaining samples without merging them.
    """

    def __init__(self, dataset: Union[Dataset, List[BaseData]],
                 batch_size: int = 1, shuffle: bool = False, **kwargs):
        # A caller-supplied collate function is dropped, not honoured.
        kwargs.pop('collate_fn', None)

        list_collate = collate_fn
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=list_collate,
            **kwargs,
        )
|
|
|