Source code for rl4co.data.dataset
from typing import Union
import torch
from tensordict.tensordict import TensorDict
from torch.utils.data import Dataset
class TensorDictDataset(Dataset):
    """Map-style dataset that stores a TensorDict as a list of per-sample dicts.

    The TensorDict is "disassembled" once, at construction time, into one plain
    ``dict`` per sample. Loading is fast and CPU usage is low, at the price of a
    somewhat slow instantiation caused by the per-sample loop.

    Note:
        Check out the issue on tensordict for more details:
        https://github.com/pytorch-labs/tensordict/issues/374.
    """

    def __init__(self, td: TensorDict):
        """Split ``td`` along its first batch dimension into a list of dicts.

        Args:
            td: source TensorDict; ``td.batch_size[0]`` is taken as the number
                of samples and every value is indexed with ``0..n-1``.
        """
        num_samples = td.batch_size[0]
        self.data_len = num_samples

        samples = []
        for sample_idx in range(num_samples):
            # One plain dict per sample: same keys as `td`, values sliced at sample_idx.
            samples.append({name: entry[sample_idx] for name, entry in td.items()})
        self.data = samples

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.data_len

    def __getitem__(self, idx):
        """Return the stored per-sample dict(s) at ``idx`` (plain list indexing)."""
        return self.data[idx]

    def add_key(self, key, value):
        """Wrap this dataset in an ``ExtraKeyDataset`` carrying an additional entry.

        ``ExtraKeyDataset`` is defined outside this class; it is called with this
        dataset, ``value`` and ``key_name=key``, and its result is returned.
        """
        return ExtraKeyDataset(self, value, key_name=key)

    @staticmethod
    def collate_fn(batch: Union[dict, TensorDict]):
        """Collate function compatible with TensorDicts that reassembles a list of dicts."""
        # The first sample determines which keys are collated.
        stacked = {}
        for name in batch[0].keys():
            stacked[name] = torch.stack([sample[name] for sample in batch])

        batch_shape = torch.Size([len(batch)])
        return TensorDict(stacked, batch_size=batch_shape, _run_checks=False)
class TensorDictDatasetFastGeneration(Dataset):
    """Dataset compatible with TensorDicts that keeps the TensorDict in one piece.

    Similar performance in loading to the list-based :class:`TensorDictDataset`,
    but faster in instantiation (more than 10x faster) because the TensorDict is
    stored as-is instead of being split into per-sample dicts.

    Warning:
        Note that directly indexing TensorDicts may be faster in creating the
        dataset but uses > 3x more CPU. We may generally recommend using
        :class:`TensorDictDataset` instead.

    Note:
        Check out the issue on tensordict for more details:
        https://github.com/pytorch-labs/tensordict/issues/374.
    """

    def __init__(self, td: TensorDict):
        """Store ``td`` without copying or splitting it."""
        self.data = td

    def __len__(self):
        """Return ``len()`` of the stored TensorDict."""
        return len(self.data)

    def __getitem__(self, index):
        """Return a single sample as a plain dict of the values at ``index``.

        Map-style datasets are expected to implement ``__getitem__``; without it,
        ``dataset[i]`` (and anything that relies on it, such as non-batched
        loading) falls through to the base class and raises
        ``NotImplementedError``. The returned format matches
        ``TensorDictDataset.__getitem__``: one plain dict per sample.
        """
        return {key: item[index] for key, item in self.data.items()}

    def __getitems__(self, index):
        """Return the whole batch selected by ``index`` as one TensorDict.

        Args:
            index: sized collection of sample indices (``len(index)`` becomes
                the batch size); every stored value is indexed with it directly.
        """
        # Tricks:
        # - batched data loading with `__getitems__` for faster loading
        # - avoid directly indexing TensorDicts for faster loading
        return TensorDict(
            {key: item[index] for key, item in self.data.items()},
            batch_size=torch.Size([len(index)]),
            _run_checks=False,  # faster this way
        )

    def add_key(self, key, value):
        """Add ``value`` under ``key`` to the stored TensorDict in place; return ``self``."""
        self.data.update({key: value})  # native method
        return self

    @staticmethod
    def collate_fn(batch: Union[dict, TensorDict]):
        """Equivalent to collating with `lambda x: x`"""
        # `__getitems__` already returns an assembled batch, so nothing is left to do.
        return batch