Base Environment
- class rl4co.envs.common.base.RL4COEnvBase(*, data_dir='data/', train_file=None, val_file=None, test_file=None, val_dataloader_names=None, test_dataloader_names=None, check_solution=True, dataset_cls=<class 'rl4co.data.dataset.TensorDictDataset'>, seed=None, device='cpu', _torchrl_mode=False, **kwargs)
Bases: EnvBase
Base class for RL4CO environments, built on TorchRL's EnvBase.
- Parameters:
data_dir (str) – Root directory for the dataset
train_file (str) – Name of the training file
val_file (str) – Name of the validation file
test_file (str) – Name of the test file
val_dataloader_names (list) – Names of the dataloaders to use for validation
test_dataloader_names (list) – Names of the dataloaders to use for testing
check_solution (bool) – Whether to check the validity of the solution at the end of the episode
dataset_cls (callable) – Dataset class to use for the environment (which can influence performance)
seed (int) – Seed for the environment
device (str) – Device to use. Generally there is no need to set it, since tensors are updated on the fly
_torchrl_mode (bool) – Whether to use the TorchRL mode (see step() for more details)
- check_solution_validity(td, actions)
Function to check whether the solution is valid. It can be called by the agent to check the validity of the current state. It is called with the full solution (i.e. all actions) at the end of the episode.
- Return type:
TensorDict
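As an illustration of what such a check can look like, a TSP-style validity test asserts that every node appears exactly once in each batch element's full action sequence. The sketch below is hypothetical: it uses plain Python lists in place of TensorDicts and tensors, the function name `check_tour_validity` is not part of RL4CO, and it simply raises on an invalid solution instead of returning anything.

```python
from typing import List


def check_tour_validity(num_nodes: int, actions: List[List[int]]) -> None:
    """Hypothetical TSP-style check: raise if any tour does not visit every node exactly once.

    `actions` holds one full action sequence per batch element, mirroring the
    complete solution that is checked at the end of the episode.
    """
    expected = list(range(num_nodes))
    for b, tour in enumerate(actions):
        # A valid tour is a permutation of 0..num_nodes-1.
        if sorted(tour) != expected:
            raise AssertionError(f"batch element {b}: invalid tour {tour}")


# Usage: the first batch element is a valid permutation, the second repeats node 1.
check_tour_validity(4, [[0, 2, 1, 3]])
try:
    check_tour_validity(4, [[0, 2, 1, 3], [0, 1, 1, 3]])
except AssertionError as err:
    print(err)  # batch element 1: invalid tour [0, 1, 1, 3]
```

Other problem types would swap in their own constraints (for example capacity limits), but the shape stays the same: inspect the complete action sequence once, after the episode ends.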
- dataset(batch_size=[], phase='train', filename=None)
Return a dataset of observations. The dataset is generated if it does not exist; otherwise, it is loaded from file.
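The generate-or-load behaviour can be sketched as follows. This is a hypothetical helper, not RL4CO's implementation: `load_or_generate`, its `num_instances`/`num_loc` parameters (standing in for `batch_size` and the environment's own instance size), and the JSON file format are all assumptions made for illustration.

```python
import json
import random
from pathlib import Path
from typing import List, Optional

Instance = List[List[float]]  # num_loc (x, y) pairs


def load_or_generate(data_dir: str, filename: Optional[str], num_instances: int,
                     num_loc: int, seed: int = 0) -> List[Instance]:
    """Hypothetical helper: load instances from data_dir/filename if that file
    exists, otherwise sample num_instances random instances in the unit square."""
    if filename is not None:
        path = Path(data_dir) / filename
        if path.exists():
            with path.open() as f:
                return json.load(f)
    rng = random.Random(seed)  # seeded so regenerated data is reproducible
    return [[[rng.random(), rng.random()] for _ in range(num_loc)]
            for _ in range(num_instances)]
```

With `filename=None`, or a name that is not found under `data_dir`, the helper falls back to generation; JSON only stands in for whatever file format a real environment would read.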
- get_action_mask(td)
Function to compute the action mask (feasible actions) for the current state. The action mask is 1 if the action is feasible and 0 otherwise.
- Return type:
Tensor
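For a TSP-like task, a node is feasible exactly when it has not been visited yet. The sketch below is hypothetical (`tsp_action_mask` is not an RL4CO function) and uses nested lists with one row per batch element where a real environment would operate on batched tensors.

```python
from typing import List


def tsp_action_mask(visited: List[List[bool]]) -> List[List[int]]:
    """Hypothetical TSP-style mask: an action (node) is feasible (1) if it has
    not been visited yet, and infeasible (0) otherwise. One row per batch element."""
    return [[0 if v else 1 for v in row] for row in visited]


# Nodes 0 and 2 are already visited in a single 4-node instance.
print(tsp_action_mask([[True, False, True, False]]))  # [[0, 1, 0, 1]]
```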
- get_reward(td, actions)
Function to compute the reward. It can be called by the agent to compute the reward of the current state. For CO tasks, this is faster than calling step() and reading the reward from the returned TensorDict at every step.
- Return type:
TensorDict
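The idea of computing the reward once from the complete action sequence, rather than accumulating it step by step, can be sketched for a TSP-like task, assuming the reward is the negative length of the closed tour. The function `tsp_reward` and the list-based inputs are hypothetical stand-ins for the TensorDict and action tensor.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def tsp_reward(locs: List[Point], tour: List[int]) -> float:
    """Hypothetical TSP reward: the negative length of the closed tour,
    so that shorter tours receive higher reward."""
    length = 0.0
    for i in range(len(tour)):
        a = locs[tour[i]]
        b = locs[tour[(i + 1) % len(tour)]]  # wrap around to return to the start
        length += math.dist(a, b)
    return -length


# Unit square visited in order: the perimeter is 4, so the reward is -4.0.
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(tsp_reward(square, [0, 1, 2, 3]))  # -4.0
```

Because the whole tour is available at the end of the episode, one pass over the actions suffices; no per-step reward bookkeeping is needed.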
- step(td)
Step function to call at each step of the episode with a TensorDict containing an action. If _torchrl_mode is True, _torchrl_step is called instead, which sets the "next" key of the TensorDict to the next state. This is the usual way to do it in TorchRL, but it is inefficient in our case.
- Return type:
TensorDict
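The difference between returning the updated state directly and the TorchRL-style step that writes the successor under "next" can be sketched with plain dictionaries standing in for TensorDicts. Everything here is hypothetical (`step_inplace`, `step_torchrl_style` and the state keys are illustrative, not RL4CO internals); the point is only that the TorchRL style keeps the current state intact and builds a separate copy for "next", which is the extra work per step.

```python
import copy
from typing import Any, Dict

State = Dict[str, Any]


def step_inplace(td: State) -> State:
    """Hypothetical direct step: update the state with the chosen action and return it."""
    action = td["action"]
    td["visited"][action] = True
    td["current_node"] = action
    td["action_mask"] = [0 if v else 1 for v in td["visited"]]  # 1 = feasible
    td["done"] = all(td["visited"])
    return td


def step_torchrl_style(td: State) -> State:
    """Hypothetical TorchRL-style step: leave the current state untouched and write
    the successor state under the "next" key, which costs an extra copy per step."""
    current = {k: v for k, v in td.items() if k != "next"}
    td["next"] = step_inplace(copy.deepcopy(current))
    return td


state = {"visited": [True, False, False], "current_node": 0, "action": 2}
out = step_torchrl_style(state)
print(out["visited"], out["next"]["visited"])  # [True, False, False] [True, False, True]
```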
- to(device)
Override of the to(device) method for safety against a None device (which may be found in a TensorDict).
- transform()
Used for converting TensorDict variables efficiently (such as with torch.cat); see https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html. By default, the environment does not need to be transformed, since specific embeddings are used.
- batch_locked = False