Base Environment
- class rl4co.envs.common.base.RL4COEnvBase(*, data_dir='data/', train_file=None, val_file=None, test_file=None, check_solution=True, seed=None, device='cpu', **kwargs)
Bases: EnvBase
Base class for RL4CO environments, based on TorchRL's EnvBase.
- Parameters:
data_dir (str) – Root directory for the dataset
train_file (Optional[str]) – Name of the training file
val_file (Optional[str]) – Name of the validation file
test_file (Optional[str]) – Name of the test file
check_solution (bool) – Whether to check the validity of the solution at the end of the episode
seed (Optional[int]) – Seed for the environment
device (str) – Device to use. This generally does not need to be set, since tensors are updated on the fly
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- check_solution_validity(td, actions)
Check whether the solution is valid. Can be called by the agent to check the validity of the current state. It is called with the full solution (i.e. all actions) at the end of the episode.
- Return type:
TensorDict
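To illustrate what a validity check on a full solution can look like, here is a minimal sketch in plain Python. It assumes a routing-style problem in which a valid solution visits every node exactly once; the function name `is_valid_tour` and this validity rule are hypothetical and not part of RL4CO, where each environment defines its own check.

```python
def is_valid_tour(actions, num_nodes):
    """Return True if `actions` visits every node index exactly once.

    Hypothetical helper: the "visit each node once" rule is an assumption
    made for illustration, not the rule used by any specific RL4CO env.
    """
    return sorted(actions) == list(range(num_nodes))


# The check runs on the full action sequence, i.e. at the end of the episode.
print(is_valid_tour([2, 0, 3, 1], 4))  # every node visited once
print(is_valid_tour([2, 1, 3, 1], 4))  # node 1 repeated, node 0 missing
```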
- dataset(batch_size=[], phase='train', filename=None)
Return a dataset of observations. Generates the dataset if it does not exist; otherwise loads it from file.
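The generate-or-load behaviour described above can be sketched as follows. This is only an illustration of the pattern in plain Python: the helper `load_or_generate`, the JSON file format, and the random 2D coordinates are assumptions, not RL4CO's actual storage format or generator.

```python
import json
import os
import random
import tempfile


def load_or_generate(path, num_instances, num_nodes, seed=0):
    """Load instances from `path` if it exists, otherwise generate and save them.

    Hypothetical helper: instances are lists of random 2D coordinates
    stored as JSON, chosen only to keep the sketch self-contained.
    """
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    rng = random.Random(seed)
    data = [
        [[rng.random(), rng.random()] for _ in range(num_nodes)]
        for _ in range(num_instances)
    ]
    with open(path, "w") as f:
        json.dump(data, f)
    return data


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "val.json")
    first = load_or_generate(path, num_instances=3, num_nodes=5)   # generated
    second = load_or_generate(path, num_instances=3, num_nodes=5)  # loaded from file
    same_data = first == second
```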
- get_action_mask(td)
Compute the action mask (feasible actions) for the current state. The action mask is 1 if the action is feasible and 0 otherwise.
- Return type:
Tensor
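As a rough illustration of the 1/0 mask convention, the sketch below builds a mask for a problem in which already-visited nodes can no longer be selected. It uses plain Python lists instead of tensors, and the function name `action_mask` and the "visited nodes are infeasible" rule are assumptions made for this example.

```python
def action_mask(visited, num_nodes):
    """Return a list with 1 for feasible actions and 0 for infeasible ones.

    Hypothetical rule for illustration: a node is feasible only if it
    has not been visited yet.
    """
    visited = set(visited)
    return [0 if node in visited else 1 for node in range(num_nodes)]


# Nodes 0 and 2 were already visited, so only nodes 1 and 3 remain feasible.
print(action_mask([0, 2], 4))  # [0, 1, 0, 1]
```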
- get_reward(td, actions)
Compute the reward. Can be called by the agent to compute the reward of the current state. For CO (combinatorial optimization) tasks, this is faster than calling step() and reading the reward from the returned TensorDict at each step.
- Return type:
TensorDict
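The idea of computing the reward once from the full action sequence, rather than accumulating it step by step, can be sketched as below. The example assumes a TSP-like problem whose reward is the negative length of the closed tour; `tour_reward` is a hypothetical name and the code uses plain Python rather than TensorDict inputs.

```python
import math


def tour_reward(coords, actions):
    """Return the negative length of the closed tour given by `actions`.

    Hypothetical helper: `coords` is a list of (x, y) pairs and `actions`
    is the full visiting order, so the reward is computed in a single pass.
    """
    length = 0.0
    for i in range(len(actions)):
        x1, y1 = coords[actions[i]]
        x2, y2 = coords[actions[(i + 1) % len(actions)]]  # wrap back to the start
        length += math.hypot(x2 - x1, y2 - y1)
    return -length


# Corners of the unit square visited in order: the tour length is 4.
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(tour_reward(square, [0, 1, 2, 3]))  # -4.0
```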
- transform()
Used for converting TensorDict variables (such as with torch.cat) efficiently. See https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html. By default, the environment does not need to be transformed, since we use specific embeddings.
- batch_locked = False
- training: bool