
Base Environment

class rl4co.envs.common.base.RL4COEnvBase(*, data_dir='data/', train_file=None, val_file=None, test_file=None, check_solution=True, seed=None, device='cpu', **kwargs)

Bases: EnvBase

Base class for RL4CO environments, built on TorchRL's EnvBase.

Parameters:
  • data_dir (str) – Root directory for the dataset

  • train_file (Optional[str]) – Name of the training file

  • val_file (Optional[str]) – Name of the validation file

  • test_file (Optional[str]) – Name of the test file

  • check_solution (bool) – Whether to check the validity of the solution at the end of the episode

  • seed (Optional[int]) – Seed for the environment

  • device (str) – Device to use. Generally there is no need to set this, since tensors are updated on the fly.


check_solution_validity(td, actions)

Function to check whether the solution is valid. It can be called by the agent to check the validity of the current state. It is called with the full solution (i.e., all actions) at the end of the episode.

Return type:

TensorDict
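To make the contract concrete, here is a minimal sketch of the kind of check a subclass might perform for a TSP-like problem, where a complete solution must visit every node exactly once. The function name check_tsp_solution_validity, the list-of-lists batch (instead of a TensorDict and tensors), and raising AssertionError on failure are all assumptions for illustration, not RL4CO's implementation.

```python
from typing import List


def check_tsp_solution_validity(actions: List[List[int]], num_nodes: int) -> None:
    """Hypothetical TSP-style check: every tour must visit each node exactly once.

    `actions` is a batch of tours, one list of node indices per instance.
    Raises AssertionError on the first invalid tour.
    """
    expected = list(range(num_nodes))
    for i, tour in enumerate(actions):
        # A valid tour is a permutation of 0..num_nodes-1.
        assert sorted(tour) == expected, f"Invalid tour at batch index {i}: {tour}"


# Usage: a permutation passes silently; a tour that repeats node 1 fails.
check_tsp_solution_validity([[0, 2, 1, 3]], num_nodes=4)
try:
    check_tsp_solution_validity([[0, 1, 1, 3]], num_nodes=4)
except AssertionError as err:
    print(err)  # Invalid tour at batch index 0: [0, 1, 1, 3]
```

Running the check once on the full action sequence, rather than at every step, matches the description above of calling it at the end of the episode.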

dataset(batch_size=[], phase='train', filename=None)

Return a dataset of observations. Generates the dataset if it does not exist; otherwise, loads it from file.
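The generate-or-load behavior can be sketched with the standard library alone. The helper load_or_generate, the JSON file format, and random_points are hypothetical; the sketch does not reproduce RL4CO's actual file format or how it resolves data_dir, phase, and filename into a path.

```python
import json
import os
import random
import tempfile
from typing import Callable, List


def load_or_generate(fpath: str, generate_fn: Callable[[int], List], batch_size: int) -> List:
    """Hypothetical helper mirroring the 'generate if missing, else load' pattern."""
    if os.path.isfile(fpath):
        # The file exists: load the stored instances instead of regenerating.
        with open(fpath) as f:
            return json.load(f)
    # The file is missing: generate, save for next time, then return.
    data = generate_fn(batch_size)
    os.makedirs(os.path.dirname(fpath) or ".", exist_ok=True)
    with open(fpath, "w") as f:
        json.dump(data, f)
    return data


def random_points(n: int) -> List[List[float]]:
    """Hypothetical generator: n random 2-D points in the unit square."""
    return [[random.random(), random.random()] for _ in range(n)]


path = os.path.join(tempfile.mkdtemp(), "val.json")
first = load_or_generate(path, random_points, 5)   # generated and written to disk
second = load_or_generate(path, random_points, 5)  # loaded from disk
print(first == second)  # True
```

Caching generated instances this way keeps validation and test sets fixed across runs, which is the usual reason for persisting them rather than regenerating.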

generate_data(batch_size)

Generate a dataset of batch_size problem instances.
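What an instance contains is up to each subclass. As an illustration only, a TSP-like environment might sample node coordinates uniformly in the unit square. The function name generate_tsp_data, its num_loc and seed parameters, and the uniform distribution are assumptions for this sketch, not RL4CO's generator.

```python
import random
from typing import List, Tuple

Point = Tuple[float, float]


def generate_tsp_data(batch_size: int, num_loc: int = 10, seed: int = 0) -> List[List[Point]]:
    """Hypothetical generator: `batch_size` TSP instances, each with `num_loc`
    locations sampled uniformly in the unit square."""
    rng = random.Random(seed)  # local RNG so generation is reproducible
    return [
        [(rng.random(), rng.random()) for _ in range(num_loc)]
        for _ in range(batch_size)
    ]


batch = generate_tsp_data(batch_size=4, num_loc=20)
print(len(batch), len(batch[0]))  # 4 20
```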

get_action_mask(td)

Function to compute the action mask (feasible actions) for the current state. The action mask is 1 if the action is feasible and 0 otherwise.

Return type:

Tensor
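For a TSP-like problem, feasibility typically reduces to "not yet visited". The sketch below assumes the state is a batch of per-node visited flags; this representation and the name get_tsp_action_mask are hypothetical, and in an actual environment this information would live in td as tensors.

```python
from typing import List


def get_tsp_action_mask(visited: List[List[bool]]) -> List[List[int]]:
    """Hypothetical TSP mask: 1 (feasible) for unvisited nodes, 0 otherwise.

    `visited[b][n]` is True if node n has already been visited in instance b.
    """
    return [[0 if seen else 1 for seen in row] for row in visited]


visited = [
    [True, False, True, False],    # nodes 0 and 2 are already in the tour
    [False, False, False, False],  # nothing visited yet
]
print(get_tsp_action_mask(visited))  # [[0, 1, 0, 1], [1, 1, 1, 1]]
```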

get_reward(td, actions)

Function to compute the reward. It can be called by the agent to compute the reward of the current state. For CO tasks, this is faster than calling step() and reading the reward from the returned TensorDict at every step.

Return type:

TensorDict
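The point above is that the reward can be computed once from the complete action sequence. A minimal sketch for a TSP-like problem, assuming the reward is the negative closed-tour length (a common convention for routing problems, but an assumption here) and using plain lists instead of a TensorDict; tour_length and get_tsp_reward are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def tour_length(locs: Sequence[Point], tour: Sequence[int]) -> float:
    """Length of the closed tour that visits `locs` in the order given by `tour`."""
    total = 0.0
    for i in range(len(tour)):
        # Pair each node with the next one, wrapping around to close the tour.
        a, b = locs[tour[i]], locs[tour[(i + 1) % len(tour)]]
        total += math.dist(a, b)
    return total


def get_tsp_reward(batch_locs: List[List[Point]], actions: List[List[int]]) -> List[float]:
    """Hypothetical reward: negative tour length, so shorter tours score higher."""
    return [-tour_length(locs, tour) for locs, tour in zip(batch_locs, actions)]


# Visiting the corners of a unit square in order traces its perimeter, length 4.
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(get_tsp_reward([square], [[0, 1, 2, 3]]))  # [-4.0]
```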

static load_data(fpath, batch_size=[])

Load a dataset from the file at fpath.

render(*args, **kwargs)

Render the environment

transform()

Used for efficiently converting TensorDict variables (for example, with torch.cat). See https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html. By default, the environment does not need to be transformed, since we use problem-specific embeddings.

batch_locked = False
training: bool