
Base Autoregressive Model


Policy

class rl4co.models.zoo.common.autoregressive.policy.AutoregressivePolicy(env_name='tsp', encoder=None, decoder=None, init_embedding=None, context_embedding=None, dynamic_embedding=None, select_start_nodes_fn=<function select_start_nodes>, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', mask_inner=True, use_graph_context=True, sdpa_fn=None, train_decode_type='sampling', val_decode_type='greedy', test_decode_type='greedy', **unused_kw)

Bases: Module

Base autoregressive policy for neural combinatorial optimization (NCO) construction methods. The policy performs the following steps:

  1. Encode the initial state of the environment into node embeddings

  2. Decode (autoregressively) to construct the solution to the NCO problem

Based on the policy from Kool et al. (2019) and extended for shared use across multiple models in RL4CO.
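The two steps above can be illustrated with a framework-free toy sketch for a TSP-like problem. It is not RL4CO code: the "encoder" simply returns the node coordinates as embeddings, and the hypothetical score function (negative distance to the current node) stands in for the learned attention logits. What it does show is the structure: encode once, then decode one node per step while masking already visited nodes.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]

def encode(coords: List[Point]) -> List[Point]:
    # Step 1 (toy): a real encoder would return learned node embeddings;
    # here the raw coordinates play that role.
    return list(coords)

def score(embeddings: List[Point], current: int, candidate: int) -> float:
    # Hypothetical stand-in for the decoder logits: closer nodes score higher.
    return -math.dist(embeddings[current], embeddings[candidate])

def decode_greedy(embeddings: List[Point], start: int = 0) -> List[int]:
    # Step 2: build the solution autoregressively, one node per step,
    # only considering nodes that are not yet visited (the action mask).
    tour = [start]
    visited = {start}
    while len(tour) < len(embeddings):
        current = tour[-1]
        candidates = [j for j in range(len(embeddings)) if j not in visited]
        nxt = max(candidates, key=lambda j: score(embeddings, current, j))
        tour.append(nxt)
        visited.add(nxt)
    return tour

coords = [(0.0, 0.0), (0.0, 1.0), (5.0, 5.0), (0.1, 0.0)]
print(decode_greedy(encode(coords)))  # [0, 3, 1, 2]
```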

Note

We recommend passing the decoding method as a keyword argument to the decoder during actual testing. The {phase}_decode_type arguments are only meant to be used during the main training loop. See the evaluation scripts for examples.
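The intended interplay between the {phase}_decode_type defaults and an explicitly passed decoding method can be sketched as follows. The helper resolve_decode_type is hypothetical; it only assumes that a decode_type given as a keyword argument takes precedence over the phase default, which is the behavior the note above recommends relying on.

```python
from typing import Any, Dict

# Defaults mirroring the constructor arguments train/val/test_decode_type.
PHASE_DEFAULTS = {"train": "sampling", "val": "greedy", "test": "greedy"}

def resolve_decode_type(phase: str, decoder_kwargs: Dict[str, Any],
                        defaults: Dict[str, str] = PHASE_DEFAULTS) -> Dict[str, Any]:
    # Hypothetical helper: fall back to the phase default only when the caller
    # did not pass decode_type explicitly. The input dict is not modified.
    kwargs = dict(decoder_kwargs)
    kwargs.setdefault("decode_type", defaults[phase])
    return kwargs

print(resolve_decode_type("train", {}))                                   # {'decode_type': 'sampling'}
print(resolve_decode_type("test", {"decode_type": "multistart_greedy"}))  # {'decode_type': 'multistart_greedy'}
```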

Parameters:
  • env_name (Union[str, RL4COEnvBase]) – Name of the environment used to initialize embeddings

  • encoder (Module) – Encoder module. Can be passed by sub-classes.

  • decoder (Module) – Decoder module. Can be passed by sub-classes.

  • init_embedding (Module) – Model to use for the initial embedding. If None, use the default embedding for the environment

  • context_embedding (Module) – Model to use for the context embedding. If None, use the default embedding for the environment

  • dynamic_embedding (Module) – Model to use for the dynamic embedding. If None, use the default embedding for the environment

  • select_start_nodes_fn (Callable) – Function to select the start nodes for multi-start decoding

  • embedding_dim (int) – Dimension of the node embeddings

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the attention layers

  • normalization (str) – Normalization type in the attention layers

  • mask_inner (bool) – Whether to mask the inner diagonal in the attention layers

  • use_graph_context (bool) – Whether to use the initial graph context to modify the query

  • sdpa_fn (Optional[Callable]) – Scaled dot product function to use for the attention

  • train_decode_type (str) – Type of decoding during training

  • val_decode_type (str) – Type of decoding during validation

  • test_decode_type (str) – Type of decoding during testing

  • **unused_kw – Unused keyword arguments


evaluate_action(td, action, env=None)

Evaluate the action probability and entropy under the current policy.

Parameters:
  • td (TensorDict) – TensorDict containing the current state

  • action (Tensor) – Action to evaluate

  • env (Union[str, RL4COEnvBase]) – Environment to evaluate the action in.

Return type:

Tuple[Tensor, Tensor]
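A minimal, framework-free sketch of what evaluating a fixed action sequence involves (for example, previously taken "old" actions): at each step, take the log-probability of the action that was actually chosen under a masked softmax, together with the entropy of that distribution. The logits are given directly here, whereas in the policy they would come from the decoder; masked_log_softmax and evaluate_actions are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple

def masked_log_softmax(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    # Infeasible actions get log-probability -inf (probability 0).
    feasible = [l for l, ok in zip(logits, mask) if ok]
    m = max(feasible)
    log_z = m + math.log(sum(math.exp(l - m) for l in feasible))
    return [l - log_z if ok else float("-inf") for l, ok in zip(logits, mask)]

def evaluate_actions(step_logits, step_masks, actions) -> Tuple[List[float], List[float]]:
    """Score a given action sequence instead of sampling a new one.

    Returns the per-step log-probability of each taken action and the
    per-step entropy of the masked action distribution.
    """
    log_p, entropy = [], []
    for logits, mask, a in zip(step_logits, step_masks, actions):
        lp = masked_log_softmax(logits, mask)
        log_p.append(lp[a])
        # H = -sum_a p(a) log p(a), summed over feasible actions only.
        entropy.append(-sum(math.exp(x) * x for x, ok in zip(lp, mask) if ok))
    return log_p, entropy

logits = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
masks = [[True, True, True, True], [True, False, True, True]]
log_p, ent = evaluate_actions(logits, masks, actions=[1, 3])
print(sum(log_p))  # log-likelihood of the whole sequence: log(1/4) + log(1/3)
```

Summing the per-step log-probabilities gives the log-likelihood of the complete solution.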

forward(td, env=None, phase='train', return_actions=False, return_entropy=False, return_init_embeds=False, **decoder_kwargs)

Forward pass of the policy.

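Parameters:
  • td (TensorDict) – TensorDict containing the current state

  • env (Union[str, RL4COEnvBase]) – Environment to use for decoding

  • phase (str) – Current phase ("train", "val" or "test"); selects the default {phase}_decode_type

  • return_actions (bool) – Whether to return the actions

  • return_entropy (bool) – Whether to return the entropy

  • return_init_embeds (bool) – Whether to return the initial embeddings

  • **decoder_kwargs – Keyword arguments passed to the decoder, e.g. decode_type

Returns:

out – Dictionary containing the reward, log likelihood, and optionally the actions and entropy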

Encoder

class rl4co.models.zoo.common.autoregressive.encoder.GraphAttentionEncoder(env_name, num_heads, embedding_dim, num_layers, normalization='batch', feed_forward_hidden=512, init_embedding=None, sdpa_fn=None)

Bases: Module

Graph Attention Encoder as in Kool et al. (2019).

Parameters:
  • env_name (Union[str, RL4COEnvBase]) – Name of the environment to solve

  • num_heads (int) – Number of heads for the attention

  • embedding_dim (int) – Dimension of the embeddings

  • num_layers (int) – Number of layers for the encoder

  • normalization (str) – Normalization to use for the attention

  • feed_forward_hidden (int) – Hidden dimension for the feed-forward network

  • init_embedding (Module) – Model to use for the initial embedding. If None, use the default embedding for the environment

  • sdpa_fn (Optional[Callable]) – Scaled dot-product function to use for the attention
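The sdpa_fn parameter allows plugging in a custom scaled dot-product attention function. For reference, the computation such a function performs is sketched below for a single head in plain Python, with an optional boolean mask following the convention of torch.nn.functional.scaled_dot_product_attention (True means the position takes part in attention). The exact signature RL4CO expects for sdpa_fn is not assumed here.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]

def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix,
                                 mask: Optional[List[List[bool]]] = None) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(d)) V.

    mask[i][j] == True means query i may attend to key j
    (every row needs at least one True).
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Attention scores of query i against every key, scaled by sqrt(d).
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            scores = [s if mask[i][j] else float("-inf") for j, s in enumerate(scores)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Weighted average of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out

q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0, 0.0], [0.0, 10.0]]
print(scaled_dot_product_attention(q, k, v, mask=[[True, False], [True, True]])[0])  # [10.0, 0.0]
```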


forward(td, mask=None)

Forward pass of the encoder. Transform the input TensorDict into a latent representation.

Parameters:
  • td (TensorDict) – Input TensorDict containing the environment state

  • mask (Optional[Tensor]) – Mask to apply to the attention

Returns:

  • h – Latent representation of the input

  • init_h – Initial embedding of the input

Decoder

class rl4co.models.zoo.common.autoregressive.decoder.AutoregressiveDecoder(env_name, embedding_dim, num_heads, use_graph_context=True, select_start_nodes_fn=<function select_start_nodes>, linear_bias=False, context_embedding=None, dynamic_embedding=None, **logit_attn_kwargs)

Bases: Module

Autoregressive decoder for constructing solutions to combinatorial optimization problems. Given the environment state and the embeddings, it computes the logits and samples actions autoregressively until all environments in the batch have reached a terminal state. Multi-start decoding is also supported; it is handled in the decoder because the attention computation can be performed natively there, which is more efficient.
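Multi-start decoding runs one rollout per start node on the same instance and keeps the best result. The sketch below shows that logic on a toy TSP, with a nearest-neighbour rule standing in for the learned decoder. Using nodes 0 to num_starts - 1 as start nodes is an assumption of this sketch, not the documented behavior of select_start_nodes, and all function names are hypothetical.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]

def tour_length(coords: List[Point], tour: List[int]) -> float:
    # Length of the closed tour (returns to the start node).
    return sum(math.dist(coords[tour[i]], coords[tour[(i + 1) % len(tour)]])
               for i in range(len(tour)))

def greedy_from(coords: List[Point], start: int) -> List[int]:
    # One rollout: repeatedly move to the nearest unvisited node.
    tour, visited = [start], {start}
    while len(tour) < len(coords):
        cur = tour[-1]
        nxt = min((j for j in range(len(coords)) if j not in visited),
                  key=lambda j: math.dist(coords[cur], coords[j]))
        tour.append(nxt)
        visited.add(nxt)
    return tour

def multistart_greedy(coords: List[Point], num_starts: int) -> Tuple[List[int], float]:
    # One rollout per start node, then keep the shortest tour.
    rollouts = [greedy_from(coords, s) for s in range(num_starts)]
    best = min(rollouts, key=lambda t: tour_length(coords, t))
    return best, tour_length(coords, best)

square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(multistart_greedy(square, num_starts=4))  # ([0, 1, 2, 3], 4.0)
```

Because the single-start rollout from node 0 is one of the candidates, the multi-start result is never worse than it.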

Note

This decoding differs from most RL problems in several major ways. The most important is that the reward is not defined for partial solutions, so we have to wait for the environment to reach a terminal state before computing the reward with env.get_reward().

Warning

We assume that environments in the done state are still available for sampling. This is because in NCO we need to wait for all environments in the batch to reach a terminal state before we can stop the decoding process. This contrasts with the TorchRL framework (at the moment), where the env.rollout function automatically resets. You can follow progress on tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72.
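The note and warning above can be illustrated with a toy batched rollout: instances finish at different steps, finished instances stay in the batch without taking new actions, decoding stops only when every instance is done, and the reward (negative tour length) is computed only once the solutions are complete. Nearest-neighbour selection stands in for the decoder, and rollout_batch is a hypothetical name.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]

def rollout_batch(batch: List[List[Point]]) -> Tuple[List[float], int]:
    tours = [[0] for _ in batch]              # every instance starts at node 0
    done = [len(c) == 1 for c in batch]
    steps = 0
    while not all(done):                      # stop only when EVERY instance is done
        for i, coords in enumerate(batch):
            if done[i]:
                # Finished instance: it stays in the batch but takes no new action.
                # A vectorized decoder would still process it with a no-op action.
                continue
            cur = tours[i][-1]
            nxt = min((j for j in range(len(coords)) if j not in tours[i]),
                      key=lambda j: math.dist(coords[cur], coords[j]))
            tours[i].append(nxt)
            done[i] = len(tours[i]) == len(coords)
        steps += 1
    # The reward is only defined for complete solutions: negative closed-tour length.
    rewards = [-sum(math.dist(c[t[s]], c[t[(s + 1) % len(t)]]) for s in range(len(t)))
               for c, t in zip(batch, tours)]
    return rewards, steps

batch = [[(0.0, 0.0), (3.0, 0.0), (3.0, 4.0)], [(0.0, 0.0), (1.0, 0.0)]]
print(rollout_batch(batch))  # ([-12.0, -2.0], 2)
```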

Parameters:
  • env_name (Union[str, RL4COEnvBase]) – Name of the environment to solve

  • embedding_dim (int) – Dimension of the embeddings

  • num_heads (int) – Number of heads for the attention

  • use_graph_context (bool) – Whether to use the initial graph context to modify the query

  • select_start_nodes_fn (Callable) – Function to select the start nodes for multi-start decoding

  • linear_bias (bool) – Whether to use a bias in the linear projection of the embeddings

  • context_embedding (Module) – Module to compute the context embedding. If None, the default is used

  • dynamic_embedding (Module) – Module to compute the dynamic embedding. If None, the default is used
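The use_graph_context and context_embedding parameters control how the decoder query is formed. As a point of reference, Kool et al. (2019) build the TSP context by concatenating a graph embedding (the mean of the node embeddings) with the embeddings of the last visited and the first node; learned placeholders replace the node embeddings at the very first step, which this sketch omits. It is illustrative only: the context modules of each RL4CO environment may differ, and the function names are hypothetical.

```python
from typing import List

Vector = List[float]

def graph_embedding(node_embeddings: List[Vector]) -> Vector:
    # Mean over nodes, as used for the graph embedding in Kool et al. (2019).
    n = len(node_embeddings)
    return [sum(col) / n for col in zip(*node_embeddings)]

def tsp_step_context(node_embeddings: List[Vector], first: int, last: int) -> Vector:
    # Kool et al. (2019) TSP context: [graph embedding ; last node ; first node].
    return graph_embedding(node_embeddings) + node_embeddings[last] + node_embeddings[first]

h = [[1.0, 0.0], [3.0, 2.0], [2.0, 4.0]]
print(tsp_step_context(h, first=0, last=2))  # [2.0, 2.0, 2.0, 4.0, 1.0, 0.0]
```

In the full model this context vector would then be projected into the query used for attention over the node embeddings.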


evaluate_action(td, embeddings, action, env=None)

Evaluate the (old) action to compute the log-likelihood of the actions and the corresponding entropy.

Parameters:
  • td (TensorDict) – Input TensorDict containing the environment state

  • embeddings (Tensor) – Precomputed embeddings for the nodes

  • action (Tensor) – Action to evaluate, of shape (batch_size, seq_len)

  • env (Union[str, RL4COEnvBase]) – Environment to use for decoding. If None, the environment is instantiated from env_name. Note that it is more efficient to pass an already instantiated environment each time for fine-grained control

Returns:

  • log_p – Tensor of shape (batch_size, seq_len, num_nodes) containing the log-likelihood of the actions

  • entropy – Tensor of shape (batch_size, seq_len) containing the entropy of the action distribution at each step

forward(td, embeddings, env=None, decode_type='sampling', num_starts=None, softmax_temp=None, calc_reward=True)

Forward pass of the decoder. Given the environment state and the precomputed embeddings, compute the logits and sample actions.

Parameters:
  • td (TensorDict) – Input TensorDict containing the environment state

  • embeddings (Tensor) – Precomputed embeddings for the nodes

  • env (Union[str, RL4COEnvBase]) – Environment to use for decoding. If None, the environment is instantiated from env_name. Note that it is more efficient to pass an already instantiated environment each time for fine-grained control

  • decode_type (str) – Type of decoding to use. Can be one of:
      - "sampling": sample from the logits
      - "greedy": take the argmax of the logits
      - "multistart_sampling": same as "sampling", but with multi-start decoding
      - "multistart_greedy": same as "greedy", but with multi-start decoding

  • num_starts (int) – Number of multi-starts to use. If None, it is computed from the action mask

  • softmax_temp (float) – Temperature for the softmax. If None, default softmax is used from the LogitAttention module

  • calc_reward (bool) – Whether to calculate the reward for the decoded sequence
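How the "greedy" and "sampling" decode types differ, and where a softmax temperature enters, can be sketched for a single decoding step. masked_probs and select_action are hypothetical helpers; the multistart_* variants would apply the same per-step rule to one rollout per start node.

```python
import math
import random
from typing import List, Optional, Sequence

def masked_probs(logits: Sequence[float], mask: Sequence[bool], temp: float = 1.0) -> List[float]:
    # Divide logits by the temperature (temp > 1 flattens, temp < 1 sharpens),
    # then apply a softmax over feasible actions only.
    scaled = [l / temp if ok else float("-inf") for l, ok in zip(logits, mask)]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def select_action(logits: Sequence[float], mask: Sequence[bool], decode_type: str = "sampling",
                  temp: float = 1.0, rng: Optional[random.Random] = None) -> int:
    probs = masked_probs(logits, mask, temp)
    if decode_type == "greedy":
        # Argmax of the probabilities (equivalently, of the masked logits).
        return max(range(len(probs)), key=probs.__getitem__)
    if decode_type == "sampling":
        rng = rng or random.Random()
        # Masked actions have weight 0 and can never be drawn.
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    raise ValueError(f"unsupported decode_type: {decode_type!r}")

print(select_action([0.0, 2.0, 1.0], [True, False, True], decode_type="greedy"))  # 2
```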

Returns:

  • outputs – Tensor of shape (batch_size, seq_len, num_nodes) containing the logits

  • actions – Tensor of shape (batch_size, seq_len) containing the sampled actions

  • td – TensorDict containing the environment state after decoding

class rl4co.models.zoo.common.autoregressive.decoder.PrecomputedCache(node_embeddings, graph_context, glimpse_key, glimpse_val, logit_key)

Bases: object

glimpse_key: Tensor
glimpse_val: Tensor
graph_context: Union[Tensor, float]
logit_key: Tensor
node_embeddings: Tensor
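PrecomputedCache groups quantities that depend only on the encoder output, so they can be computed once per instance before the decoding loop and reused at every step. The sketch below mirrors its field names with plain lists. The toy projection weights, the mean-pooled graph context, and the use of the scalar 0.0 when the graph context is disabled (suggested by the Union[Tensor, float] annotation) are assumptions for illustration, and ToyPrecomputedCache and precompute are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Union

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain-Python matrix product a @ b.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

@dataclass
class ToyPrecomputedCache:
    # Same field names as PrecomputedCache; plain lists instead of tensors.
    node_embeddings: Matrix
    graph_context: Union[List[float], float]
    glimpse_key: Matrix
    glimpse_val: Matrix
    logit_key: Matrix

def precompute(node_embeddings: Matrix, w_k: Matrix, w_v: Matrix, w_l: Matrix,
               use_graph_context: bool = True) -> ToyPrecomputedCache:
    # These projections depend only on the encoder output, so they are computed
    # once per instance and reused at every decoding step.
    n = len(node_embeddings)
    graph = [sum(col) / n for col in zip(*node_embeddings)] if use_graph_context else 0.0
    return ToyPrecomputedCache(node_embeddings, graph,
                               matmul(node_embeddings, w_k),
                               matmul(node_embeddings, w_v),
                               matmul(node_embeddings, w_l))

h = [[1.0, 0.0], [0.0, 1.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
cache = precompute(h, eye, eye, eye)
print(cache.graph_context)  # [0.5, 0.5]
```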