Data
Datasets
- class rl4co.data.dataset.ExtraKeyDataset(dataset, extra)
Bases: Dataset
Dataset that includes an extra key to add to the data dict. This is useful for adding a REINFORCE baseline reward to the data dict.
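The idea of attaching an extra per-sample value to each data dict can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the class name `ExtraKeySketch`, the `key` argument, and the default key name `"extra"` are all assumptions, and plain lists stand in for tensors.

```python
class ExtraKeySketch:
    """Wrap an indexable dataset of dicts and attach one extra value per sample."""

    def __init__(self, dataset, extra, key="extra"):
        # The key name "extra" is an assumption made for this sketch.
        if len(dataset) != len(extra):
            raise ValueError("dataset and extra must have the same length")
        self.dataset = dataset
        self.extra = extra
        self.key = key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Shallow copy so the underlying sample is left untouched.
        sample = dict(self.dataset[idx])
        sample[self.key] = self.extra[idx]
        return sample


base = [{"locs": [0.1, 0.2]}, {"locs": [0.3, 0.4]}]
baseline_rewards = [-3.5, -4.0]  # made-up values for illustration only
wrapped = ExtraKeySketch(base, baseline_rewards)
```

Each item returned by `wrapped[i]` then carries both the original fields and the matching baseline value, so a training step can read them from the same dict.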
- class rl4co.data.dataset.TensorDictDataset(data)
Bases: Dataset
Dataset compatible with TensorDicts. It is better to “disassemble” the TensorDict into a list of dicts. See tensordict_collate_fn for more details.
Note
Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374. Note that directly indexing TensorDicts may be faster in creating the dataset but uses > 3x more CPU.
- rl4co.data.dataset.tensordict_collate_fn(batch)
Collate function compatible with TensorDicts. Reassembles the list of dicts into a TensorDict; this seems to be much more efficient than using a TensorDictDataset.
Note
Check out the issue on tensordict for more details: https://github.com/pytorch-labs/tensordict/issues/374. Note that directly indexing TensorDicts may be faster in creating the dataset but uses > 3x more CPU.
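The disassemble-then-reassemble pattern described for the dataset and its collate function can be illustrated without TensorDict. The sketch below uses a plain dict of lists as a stand-in; the helper names `disassemble` and `collate` are hypothetical, and the real functions operate on TensorDicts and tensors rather than Python lists.

```python
def disassemble(columns):
    """Turn a dict of equal-length lists into a list of per-sample dicts."""
    keys = list(columns)
    n = len(columns[keys[0]]) if keys else 0
    return [{k: columns[k][i] for k in keys} for i in range(n)]


def collate(batch):
    """Reassemble a list of per-sample dicts into one dict of lists."""
    return {k: [sample[k] for sample in batch] for k in batch[0]}


# Stand-in for a batched data container (made-up values).
columns = {"locs": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], "demand": [1, 2, 3]}

samples = disassemble(columns)   # done once, when the dataset is built
batch = collate(samples[:2])     # done per batch, by the data loader
```

The dataset then only ever indexes a Python list, and the batched structure is rebuilt once per batch in the collate step.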
Data Generation
- rl4co.data.generate_data.generate_dataset(filename=None, data_dir='data', name=None, problem='all', data_distribution='all', dataset_size=10000, graph_sizes=[20, 50, 100], overwrite=False, seed=1234, disable_warning=True)
We keep a structure similar to that of Kool et al. (2019) but save and load the data as npz. This is much faster and more memory-efficient than pickle and also allows for easy transfer to TensorDict.
- rl4co.data.generate_data.generate_default_datasets(data_dir)
Generate the default datasets used in the paper and save them to data_dir/problem.
- rl4co.data.generate_data.generate_env_data(env_type, *args, **kwargs)
Generate data for a given environment type in the form of a dictionary.
- rl4co.data.generate_data.generate_mdpp_data(dataset_size, size=10, num_probes_min=2, num_probes_max=5, num_keepout_min=1, num_keepout_max=50, lock_size=True)
Generate data for the mDPP problem. If lock_size is True, the size is fixed and the size argument is ignored if it is not 10. This is because the RL environment is based on a real-world PCB (parametrized with data).
Transforms
- class rl4co.data.transforms.StateAugmentation(env_name=None, num_augment=8, use_dihedral_8=False, normalize=False)
Bases: object
Augment state by N times via symmetric rotation/reflection transform.
- Parameters:
- rl4co.data.transforms.dihedral_8_augmentation(xy)
Augmentation (x8) for grid-based data (x, y) as done in POMO. This is a dihedral group of order 8 (rotations and reflections): https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
- Parameters:
xy (Tensor) – [batch, graph, 2] tensor of x and y coordinates
- Return type:
Tensor
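The eight symmetries of the unit square can be written out explicitly. The following is a plain-Python sketch on lists of (x, y) tuples, assuming coordinates normalized to [0, 1]; the function name `dihedral_8_sketch` and the ordering of the eight transforms are assumptions, and the library function works on batched tensors instead.

```python
def dihedral_8_sketch(points):
    """Return 8 copies of `points`, one per symmetry of the unit square."""
    transforms = [
        lambda x, y: (x, y),          # identity
        lambda x, y: (1 - x, y),      # reflect left/right
        lambda x, y: (x, 1 - y),      # reflect up/down
        lambda x, y: (1 - x, 1 - y),  # rotate by 180 degrees
        lambda x, y: (y, x),          # reflect across the main diagonal
        lambda x, y: (1 - y, x),      # rotate by 90 degrees
        lambda x, y: (y, 1 - x),      # rotate by 270 degrees
        lambda x, y: (1 - y, 1 - x),  # reflect across the anti-diagonal
    ]
    return [[t(x, y) for (x, y) in points] for t in transforms]


instance = [(0.25, 0.125), (0.5, 0.75)]  # made-up coordinates
augmented = dihedral_8_sketch(instance)
```

All eight copies describe the same routing instance up to symmetry, so tour lengths are unchanged while the model sees different coordinate inputs.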
- rl4co.data.transforms.dihedral_8_augmentation_wrapper(xy, reduce=True, *args, **kw)
Wrapper for dihedral_8_augmentation. If reduce, only return the first 1/8 of the augmented data since the augmentation augments the data 8 times.
- Return type:
Tensor
- rl4co.data.transforms.env_aug_feats(env_name=None)
Return the features to augment for a given environment. Usually, locs already includes the depot, so the depot does not need to be augmented separately.
- rl4co.data.transforms.symmetric_augmentation(xy, num_augment=8, first_augment=False)
Augment xy data by num_augment times via symmetric rotation transform and concatenate to original data.
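A rotation-based augmentation of this kind can be sketched in plain Python. This is only an illustration under stated assumptions: the helper names are hypothetical, rotations are taken about the centre (0.5, 0.5) of the unit square, the first copy is left unrotated, and angles are drawn uniformly at random. The library function may differ in each of these details.

```python
import math
import random


def rotate_about_center(points, angle, center=(0.5, 0.5)):
    """Rotate (x, y) points by `angle` radians about `center`."""
    cx, cy = center
    c, s = math.cos(angle), math.sin(angle)
    return [
        (cx + c * (x - cx) - s * (y - cy), cy + s * (x - cx) + c * (y - cy))
        for x, y in points
    ]


def symmetric_augment_sketch(points, num_augment=8, seed=0):
    """Return the original points plus num_augment - 1 randomly rotated copies."""
    rng = random.Random(seed)  # seeded so the sketch is reproducible
    copies = [list(points)]    # assumption: the first copy stays unrotated
    for _ in range(num_augment - 1):
        angle = rng.uniform(0.0, 2.0 * math.pi)
        copies.append(rotate_about_center(points, angle))
    return copies


points = [(0.25, 0.125), (0.75, 0.5)]  # made-up coordinates
copies = symmetric_augment_sketch(points, num_augment=4)
```

Rotation preserves all pairwise distances, so every copy has the same optimal tour length as the original; note that rotated points can fall outside the unit square.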