Data
Datasets
- class rl4co.data.dataset.ExtraKeyDataset(dataset, extra, key_name='extra')
Bases: TensorDictDataset
Dataset that includes an extra key to add to the data dict. This is useful for adding a REINFORCE baseline reward to the data dict. Note that this is faster to instantiate than using list comprehension.
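The idea of attaching one extra value per sample can be illustrated with a minimal, dependency-free sketch. `ExtraKeySketch` is a hypothetical name, not the library class, and it assumes the wrapped dataset yields plain dicts rather than TensorDicts; the reward values are made up.

```python
class ExtraKeySketch:
    """Wrap an indexable dataset of dicts and attach one extra value per sample."""

    def __init__(self, dataset, extra, key_name="extra"):
        assert len(dataset) == len(extra), "one extra value per sample is required"
        self.dataset = dataset
        self.extra = extra
        self.key_name = key_name

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = dict(self.dataset[idx])  # shallow copy: the wrapped sample is not mutated
        sample[self.key_name] = self.extra[idx]
        return sample


base = [{"locs": [0.1, 0.2]}, {"locs": [0.3, 0.4]}]
baseline_rewards = [-3.5, -4.0]  # made-up values, for illustration only
wrapped = ExtraKeySketch(base, baseline_rewards)
```

Because the extra value is looked up at access time, nothing has to be rebuilt when the wrapper is created.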
- class rl4co.data.dataset.TensorDictDataset(td)
Bases: Dataset
Dataset compatible with TensorDicts with low CPU usage. Loading is fast, but instantiation is somewhat slow because a list comprehension “disassembles” the TensorDict into a list of dicts.
Note
Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374.
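The “disassembly” step can be sketched without any dependencies. This is an illustration only: it assumes a plain dict of equally long sequences stands in for the TensorDict, and `disassemble` is a hypothetical helper name.

```python
def disassemble(batch):
    """Split a dict of equally long sequences into a list of per-sample dicts."""
    sizes = {len(value) for value in batch.values()}
    assert len(sizes) == 1, "all entries must share the same batch size"
    n = sizes.pop()
    # One small dict per sample: this comprehension is the slow part at instantiation.
    return [{key: value[i] for key, value in batch.items()} for i in range(n)]


batch = {"locs": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], "demand": [1, 2, 3]}
samples = disassemble(batch)
```

Once the list exists, fetching a sample is a plain list lookup, which is consistent with the fast loading described above.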
- class rl4co.data.dataset.TensorDictDatasetFastGeneration(td)
Bases: Dataset
Dataset compatible with TensorDicts. Loading performance is similar to the list-comprehension approach, but instantiation is more than 10x faster than TensorDictDatasetList.
Warning
Note that directly indexing TensorDicts may be faster in creating the dataset but uses > 3x more CPU. We generally recommend using TensorDictDatasetList.
Note
Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374.
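A rough sketch of the direct-indexing trade-off, assuming a plain dict of sequences in place of a TensorDict. `LazyIndexSketch` is a hypothetical name, and whether the library class works exactly this way is an assumption.

```python
class LazyIndexSketch:
    """Keep the batched data as-is and build each sample only when requested."""

    def __init__(self, batch):
        self.batch = batch  # nothing is copied or split at instantiation
        self.size = len(next(iter(batch.values())))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # The indexing cost is paid on every access instead of once up front.
        return {key: value[idx] for key, value in self.batch.items()}


batch = {"locs": [[0.1, 0.2], [0.3, 0.4]], "demand": [1, 2]}
lazy = LazyIndexSketch(batch)
```

Creation is cheap here, but every lookup re-indexes the batched data, which is one plausible reason for the higher CPU usage mentioned in the warning.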
Data Generation
- rl4co.data.generate_data.generate_dataset(filename=None, data_dir='data', name=None, problem='all', data_distribution='all', dataset_size=10000, graph_sizes=[20, 50, 100], overwrite=False, seed=1234, disable_warning=True, distributions_per_problem=None)
We keep a structure similar to Kool et al. 2019 but save and load the data as npz. This is much faster and more memory-efficient than pickle and also allows for easy transfer to TensorDict.
- Parameters:
  - filename (Union[str, List[str]]) – Filename to save the data to. If None, the data is saved to data_dir/problem/problem_graph_size_seed.npz. Defaults to None.
  - data_dir (str) – Directory to save the data to. Defaults to “data”.
  - name (str) – Name of the dataset. Defaults to None.
  - problem (Union[str, List[str]]) – Problem to generate data for. Defaults to “all”.
  - data_distribution (str) – Data distribution to generate data for. Defaults to “all”.
  - dataset_size (int) – Number of datasets to generate. Defaults to 10000.
  - graph_sizes (Union[int, List[int]]) – Graph size to generate data for. Defaults to [20, 50, 100].
  - overwrite (bool) – Whether to overwrite existing files. Defaults to False.
  - seed (int) – Random seed. Defaults to 1234.
  - disable_warning (bool) – Whether to disable warnings. Defaults to True.
  - distributions_per_problem (Union[int, dict]) – Number of distributions to generate per problem. Defaults to None.
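The default location used when filename is None can be sketched as below. This reads the documented pattern data_dir/problem/problem_graph_size_seed.npz literally; the exact separators are an assumption, and `default_npz_path` is a hypothetical helper, not part of the library.

```python
from pathlib import Path


def default_npz_path(data_dir, problem, graph_size, seed):
    """Build a path following the documented data_dir/problem/... pattern (assumed format)."""
    return Path(data_dir) / problem / f"{problem}_{graph_size}_{seed}.npz"


path = default_npz_path("data", "tsp", 50, 1234)
```

With the default graph_sizes of [20, 50, 100], one such file per graph size would be expected for each problem.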
- rl4co.data.generate_data.generate_default_datasets(data_dir, generate_eda=False)
Generate the default datasets used in the paper and save them to data_dir/problem.
- rl4co.data.generate_data.generate_env_data(env_type, *args, **kwargs)
Generate data for a given environment type in the form of a dictionary.
- rl4co.data.generate_data.generate_mdpp_data(dataset_size, size=10, num_probes_min=2, num_probes_max=5, num_keepout_min=1, num_keepout_max=50, lock_size=True)
Generate data for the mDPP problem. If lock_size is True, the size is fixed and the size argument is skipped if it is not 10. This is because the RL environment is based on a real-world PCB (parametrized with data).
- rl4co.data.generate_data.generate_op_data(dataset_size, op_size, prize_type='const', max_lengths=None)
Transforms
- class rl4co.data.transforms.StateAugmentation(env_name=None, num_augment=8, use_dihedral_8=False, normalize=False, feats=None)
Bases: object
Augment state N times via symmetric rotation/reflection transforms.
- rl4co.data.transforms.dihedral_8_augmentation(xy)
Augmentation (x8) for grid-based data (x, y) as done in POMO. The eight transforms form the dihedral group of order 8 (rotations and reflections): https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
- Parameters:
  - xy (Tensor) – [batch, graph, 2] tensor of x and y coordinates
- Return type:
Tensor
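The eight symmetries can be written out explicitly. The sketch below is dependency-free and works on a list of (x, y) points, whereas the library function takes a [batch, graph, 2] tensor. It assumes coordinates lie in the unit square and uses the ordering commonly seen in POMO-style code; the library's ordering may differ, and `dihedral_8_sketch` is a hypothetical name.

```python
def dihedral_8_sketch(points):
    """Return the 8 symmetric copies of a list of (x, y) points in [0, 1]^2."""
    transforms = [
        lambda x, y: (x, y),          # identity
        lambda x, y: (1 - x, y),      # reflect across the vertical axis
        lambda x, y: (x, 1 - y),      # reflect across the horizontal axis
        lambda x, y: (1 - x, 1 - y),  # rotate by 180 degrees
        lambda x, y: (y, x),          # reflect across the main diagonal
        lambda x, y: (1 - y, x),      # rotate by 90 degrees
        lambda x, y: (y, 1 - x),      # rotate by 270 degrees
        lambda x, y: (1 - y, 1 - x),  # reflect across the anti-diagonal
    ]
    return [[t(x, y) for x, y in points] for t in transforms]


instance = [(0.25, 0.5), (0.75, 1.0)]
augmented = dihedral_8_sketch(instance)
```

Each copy preserves all pairwise distances, so tour lengths are unchanged; this is what makes the augmentation safe for routing problems.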
- rl4co.data.transforms.dihedral_8_augmentation_wrapper(xy, reduce=True, *args, **kw)
Wrapper for dihedral_8_augmentation. If reduce, only return the first 1/8 of the augmented data since the augmentation augments the data 8 times.
- Return type:
Tensor
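The reduce behaviour amounts to a slice over the leading dimension. The sketch below assumes the eight copies are stacked one after another, so the first eighth holds one copy of every instance; `keep_first_eighth` is a hypothetical helper operating on a plain list.

```python
def keep_first_eighth(augmented_rows):
    """Keep only the first 1/8 of rows from data that was augmented 8 times."""
    return augmented_rows[: len(augmented_rows) // 8]


# 8 copies of 3 instances, stacked copy by copy (assumed layout)
rows = [f"copy{c}_inst{i}" for c in range(8) for i in range(3)]
reduced = keep_first_eighth(rows)
```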
- rl4co.data.transforms.env_aug_feats(env_name=None)
Return the features to augment for a given environment. Usually, locs already includes the depot, so the depot does not need to be augmented separately.
- rl4co.data.transforms.symmetric_augmentation(xy, num_augment=8, first_augment=False)
Augment xy data num_augment times via a symmetric rotation transform and concatenate the result to the original data.
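A rotation-based augmentation of this kind might look like the following dependency-free sketch. It is not the library's implementation: it assumes rotations about the centre of the unit square by uniformly random angles, keeps the original instance as the first copy, and uses hypothetical names throughout.

```python
import math
import random


def rotate_about_center(points, angle, center=(0.5, 0.5)):
    """Rotate (x, y) points by `angle` radians around `center`."""
    cx, cy = center
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    return [
        (cx + cos_a * (x - cx) - sin_a * (y - cy),
         cy + sin_a * (x - cx) + cos_a * (y - cy))
        for x, y in points
    ]


def symmetric_augmentation_sketch(points, num_augment=8, rng=None):
    """Return the original points followed by num_augment - 1 rotated copies."""
    rng = rng or random.Random(0)  # fixed seed so the sketch is reproducible
    copies = [list(points)]        # assumption: the first copy is left untouched
    for _ in range(num_augment - 1):
        angle = rng.uniform(0.0, 2.0 * math.pi)
        copies.append(rotate_about_center(points, angle))
    return copies


instance = [(0.2, 0.3), (0.9, 0.5)]
augmented = symmetric_augmentation_sketch(instance, num_augment=4)
```

Rotations preserve distances between points, so each copy is an equivalent instance with the same optimal tour length. Rotated points may leave the unit square, which a real implementation may need to handle.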