Model Zoo

Models from the literature and contributions are contained in the Model Zoo.


Auto-Regressive Models

Attention Model (AM)

class rl4co.models.zoo.am.model.AttentionModel(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)[source]

Bases: REINFORCE

Attention Model based on REINFORCE: https://arxiv.org/abs/1803.08475.

Parameters:
  • env (RL4COEnvBase) – Environment to use for the algorithm

  • policy (AttentionModelPolicy) – Policy to use for the algorithm

  • baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)

  • policy_kwargs – Keyword arguments for policy

  • baseline_kwargs – Keyword arguments for baseline

  • **kwargs – Keyword arguments passed to the superclass
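
The default baseline described above starts with an exponential baseline before switching to a greedy rollout. As a minimal sketch of the exponential part only, the plain-Python snippet below keeps a running average of batch rewards and subtracts it from each reward to obtain REINFORCE advantages. The function names and the smoothing factor `beta=0.8` are illustrative assumptions, not part of the rl4co API.

```python
def exponential_baseline_update(baseline, rewards, beta=0.8):
    """Update a running (exponential moving average) baseline with one batch of rewards."""
    batch_mean = sum(rewards) / len(rewards)
    if baseline is None:  # first batch: start from the batch mean
        return batch_mean
    return beta * baseline + (1.0 - beta) * batch_mean


def reinforce_advantages(rewards, baseline):
    """Advantage of each sampled solution relative to the baseline."""
    return [r - baseline for r in rewards]


# Two hypothetical batches of (negative tour length) rewards.
baseline = None
for batch in ([-5.0, -7.0], [-4.0, -6.0]):
    baseline = exponential_baseline_update(baseline, batch)

advantages = reinforce_advantages([-4.0, -6.0], baseline)
print(baseline, advantages)
```

Solutions with a positive advantage have their log-likelihood increased by the policy gradient, and those with a negative advantage have it decreased.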

class rl4co.models.zoo.am.policy.AttentionModelPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)[source]

Bases: AutoregressivePolicy

Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475. We re-declare the most important arguments here for convenience as in the paper. See AutoregressivePolicy superclass for more details.

Parameters:
  • env_name (str) – Name of the environment used to initialize embeddings

  • embedding_dim (int) – Dimension of the node embeddings

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the attention layers

  • normalization (str) – Normalization type in the attention layers

  • **kwargs – keyword arguments passed to the AutoregressivePolicy

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Attention Model (AM-PPO)

class rl4co.models.zoo.ppo.model.PPOModel(env, policy=None, critic=None, policy_kwargs={}, critic_kwargs={}, **kwargs)[source]

Bases: PPO

PPO Model based on Proximal Policy Optimization (PPO).

Parameters:
  • env (RL4COEnvBase) – Environment to use for the algorithm

  • policy (PPOPolicy) – Policy to use for the algorithm

  • critic (CriticNetwork) – Critic to use for the algorithm

  • policy_kwargs (dict) – Keyword arguments for policy

  • critic_kwargs (dict) – Keyword arguments for critic

class rl4co.models.zoo.ppo.policy.PPOPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)[source]

Bases: AutoregressivePolicy

PPO Policy. The backbone model is inspired by Kool et al. (2019): https://arxiv.org/abs/1803.08475. This is simply a wrapper around the AutoregressivePolicy class. PPO needs an evaluate_actions method inside AutoregressivePolicy to obtain the log probabilities and entropy of actions under the current policy.

Parameters:
  • env_name (str) – Name of the environment used to initialize embeddings

  • embedding_dim (int) – Dimension of the node embeddings

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the attention layers

  • normalization (str) – Normalization type in the attention layers

  • **kwargs – keyword arguments passed to the AutoregressivePolicy
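
The PPOPolicy description above says that PPO needs the log probabilities and entropy of actions under the current policy. The following is a minimal plain-Python sketch of those two quantities for a single decoding step, given the logits over candidate nodes. The helper names are hypothetical and this is not the library's `evaluate_actions` implementation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_z for x in logits]


def evaluate_action(logits, action):
    """Return (log-probability of `action`, entropy of the action distribution)."""
    log_probs = log_softmax(logits)
    entropy = -sum(math.exp(lp) * lp for lp in log_probs)
    return log_probs[action], entropy


# Uniform logits over four nodes: every action has probability 1/4.
log_prob, entropy = evaluate_action([0.0, 0.0, 0.0, 0.0], action=2)
print(log_prob, entropy)
```

In PPO, the log probability feeds the clipped probability ratio and the entropy is typically used as an exploration bonus.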

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Heterogeneous Attention Model (HAM)

class rl4co.models.zoo.ham.model.HeterogeneousAttentionModel(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)[source]

Bases: REINFORCE

Heterogeneous Attention Model for solving the Pickup and Delivery Problem based on REINFORCE: https://arxiv.org/abs/2110.02634.

Parameters:
  • env (RL4COEnvBase) – Environment to use for the algorithm

  • policy (HeterogeneousAttentionModelPolicy) – Policy to use for the algorithm

  • baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)

  • policy_kwargs – Keyword arguments for policy

  • baseline_kwargs – Keyword arguments for baseline

  • **kwargs – Keyword arguments passed to the superclass

class rl4co.models.zoo.ham.policy.HeterogeneousAttentionModelPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)[source]

Bases: AutoregressivePolicy

Heterogeneous Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475. We re-declare the most important arguments here for convenience as in the paper. See AutoregressivePolicy superclass for more details.

Parameters:
  • env_name (str) – Name of the environment used to initialize embeddings

  • encoder – Encoder to use for the policy

  • embedding_dim (int) – Dimension of the node embeddings

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the attention layers

  • normalization (str) – Normalization type in the attention layers

  • **kwargs – keyword arguments passed to the AutoregressivePolicy

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class rl4co.models.zoo.ham.encoder.GraphHeterogeneousAttentionEncoder(num_heads=8, embedding_dim=128, num_encoder_layers=3, env_name=None, normalization='batch', feed_forward_hidden=512, sdpa_fn=None)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, mask=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class rl4co.models.zoo.ham.encoder.HeterogeneuousMHALayer(num_heads, embed_dim, feed_forward_hidden=512, normalization='batch')[source]

Bases: Sequential

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class rl4co.models.zoo.ham.attention.HeterogenousMHA(num_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None)[source]

Bases: Module

Heterogeneous Multi-Head Attention for Pickup and Delivery problems: https://arxiv.org/abs/2110.02634

forward(q, h=None, mask=None)[source]
Parameters:
  • q – queries (batch_size, n_query, input_dim)

  • h – data (batch_size, graph_size, input_dim)

  • mask – mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1). Mask should contain 1 if attention is not possible
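
The mask convention above (1 where attention is not possible) can be illustrated with a small plain-Python sketch for one query. Masked scores are set to negative infinity before the softmax, so they receive zero weight. The function name is hypothetical and the snippet ignores heads and batching.

```python
import math


def masked_attention_weights(scores, mask):
    """Softmax over `scores`, giving zero weight wherever `mask` is 1 (attention not possible)."""
    if all(mask):
        raise ValueError("at least one position must be attendable")
    masked = [-math.inf if m else s for s, m in zip(scores, mask)]
    top = max(masked)
    exps = [math.exp(s - top) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# The second node is masked out, so it gets exactly zero attention.
weights = masked_attention_weights([1.0, 2.0, 3.0], mask=[0, 1, 0])
print(weights)
```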

init_parameters()[source]

Matrix Encoding Network (MatNet)

class rl4co.models.zoo.matnet.model.MatNet(env, policy=None, optimizer_kwargs={'lr': 0.0004, 'weight_decay': 1e-06}, lr_scheduler='MultiStepLR', lr_scheduler_kwargs={'gamma': 0.1, 'milestones': [2001, 2101]}, use_dihedral_8=False, num_starts=None, train_data_size=10000, batch_size=200, policy_params={}, model_params={})[source]

Bases: POMO

class rl4co.models.zoo.matnet.policy.MatNetPolicy(env_name, embedding_dim=256, num_encoder_layers=5, num_heads=16, normalization='instance', init_embedding_kwargs={'mode': 'RandomOneHot'}, use_graph_context=False, **kwargs)[source]

Bases: AutoregressivePolicy

MatNet Policy from Kwon et al., 2021. Reference: https://arxiv.org/abs/2106.11113

Warning

This implementation is under development and subject to change.

Parameters:
  • env_name (str) – Name of the environment used to initialize embeddings

  • embedding_dim (int) – Dimension of the node embeddings

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the attention layers

  • normalization (str) – Normalization type in the attention layers

  • **kwargs – keyword arguments passed to the AutoregressivePolicy

Default parameters are adopted from the original implementation.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class rl4co.models.zoo.matnet.encoder.MatNetATSPInitEmbedding(embedding_dim, mode='RandomOneHot')[source]

Bases: Module

Preparing the initial row and column embeddings for ATSP.

Reference: https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td)[source]

class rl4co.models.zoo.matnet.encoder.MatNetCrossMHA(embedding_dim, num_heads, bias=True, mixer_hidden_dim=16, mix1_init=0.7071067811865476, mix2_init=0.25)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(q_input, kv_input, dmat)[source]
Parameters:
  • q_input (Tensor) – [b, m, d]

  • kv_input (Tensor) – [b, n, d]

  • dmat (Tensor) – [b, m, n]

Returns:

[b, m, d]

Return type:

Tensor
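
To make the documented shapes concrete, here is a single-head, plain-Python sketch of cross attention whose scores are mixed with the distance matrix. It is a simplification under stated assumptions. The learned mixing network suggested by `mixer_hidden_dim` is replaced by a fixed weighted sum (`w_score`, `w_dist` are hypothetical), there are no learned projections, and `kv_input` serves as both keys and values.

```python
import math


def cross_attention(q_input, kv_input, dmat, w_score=1.0, w_dist=-1.0):
    """q_input: [b, m, d], kv_input: [b, n, d], dmat: [b, m, n] -> output: [b, m, d]."""
    out = []
    for q_rows, kv_rows, d_rows in zip(q_input, kv_input, dmat):  # loop over batch b
        batch_out = []
        for q, d_row in zip(q_rows, d_rows):  # loop over the m queries
            dim = len(q)
            # One mixed score per key: scaled dot product combined with the distance entry.
            mixed = [
                w_score * sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) + w_dist * dist
                for k, dist in zip(kv_rows, d_row)
            ]
            top = max(mixed)
            exps = [math.exp(s - top) for s in mixed]
            total = sum(exps)
            weights = [e / total for e in exps]  # n attention weights
            # Weighted sum of the n value vectors gives one d-dimensional output.
            batch_out.append(
                [sum(w * k[j] for w, k in zip(weights, kv_rows)) for j in range(dim)]
            )
        out.append(batch_out)
    return out


# b=1, m=2, n=3, d=2
q_input = [[[1.0, 0.0], [0.0, 1.0]]]
kv_input = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]
dmat = [[[0.0, 1.0, 2.0], [1.0, 0.0, 2.0]]]
out = cross_attention(q_input, kv_input, dmat)
print(len(out), len(out[0]), len(out[0][0]))
```

The output has one d-dimensional vector per query row, matching the documented [b, m, d] return shape.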

class rl4co.models.zoo.matnet.encoder.MatNetEncoder(embedding_dim=256, num_heads=16, num_layers=5, normalization='instance', feed_forward_hidden=512, init_embedding=None, init_embedding_kwargs=None)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td)[source]

class rl4co.models.zoo.matnet.encoder.MatNetMHA(embedding_dim, num_heads, bias=True)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(row_emb, col_emb, dmat)[source]
Parameters:
  • row_emb (Tensor) – [b, m, d]

  • col_emb (Tensor) – [b, n, d]

  • dmat (Tensor) – [b, m, n]

Returns:

  • Updated row_emb (Tensor) – [b, m, d]

  • Updated col_emb (Tensor) – [b, n, d]

class rl4co.models.zoo.matnet.encoder.MatNetMHALayer(embedding_dim, num_heads, bias=True, feed_forward_hidden=512, normalization='instance')[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(row_emb, col_emb, dmat)[source]
Parameters:
  • row_emb (Tensor) – [b, m, d]

  • col_emb (Tensor) – [b, n, d]

  • dmat (Tensor) – [b, m, n]

Returns:

  • Updated row_emb (Tensor) – [b, m, d]

  • Updated col_emb (Tensor) – [b, n, d]

class rl4co.models.zoo.matnet.encoder.MatNetMHANetwork(embedding_dim=128, num_heads=8, num_layers=3, normalization='batch', feed_forward_hidden=512)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(row_emb, col_emb, dmat)[source]
Parameters:
  • row_emb (Tensor) – [b, m, d]

  • col_emb (Tensor) – [b, n, d]

  • dmat (Tensor) – [b, m, n]

Returns:

  • Updated row_emb (Tensor) – [b, m, d]

  • Updated col_emb (Tensor) – [b, n, d]

class rl4co.models.zoo.matnet.decoder.MatNetDecoder(env_name, embedding_dim, num_heads, use_graph_context=True, select_start_nodes_fn=<function select_start_nodes>, linear_bias=False, context_embedding=None, dynamic_embedding=None, **logit_attn_kwargs)[source]

Bases: AutoregressiveDecoder

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class rl4co.models.zoo.matnet.decoder.PrecomputedCache(node_embeddings, graph_context, glimpse_key, glimpse_val, logit_key)[source]

Bases: object

glimpse_key: Tensor
glimpse_val: Tensor
graph_context: Union[Tensor, float]
logit_key: Tensor
node_embeddings: Tensor

Multi-Decoder Attention Model (MDAM)

class rl4co.models.zoo.mdam.model.MDAM(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)[source]

Bases: REINFORCE

Multi-Decoder Attention Model (MDAM) is a model to train multiple diverse policies, which effectively increases the chance of finding good solutions compared with existing methods that train only one policy. Reference link: https://arxiv.org/abs/2012.10638; Implementation reference: https://github.com/liangxinedu/MDAM.

Parameters:
  • env (RL4COEnvBase) – Environment to use for the algorithm

  • policy (MDAMPolicy) – Policy to use for the algorithm

  • baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)

  • policy_kwargs – Keyword arguments for policy

  • baseline_kwargs – Keyword arguments for baseline

  • **kwargs – Keyword arguments passed to the superclass

class rl4co.models.zoo.mdam.policy.MDAMPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)[source]

Bases: AutoregressivePolicy

Multi-Decoder Attention Model (MDAM) policy.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td, env=None, phase='train', return_actions=False, **decoder_kwargs)[source]

Forward pass of the policy.

Returns:

out – Dictionary containing the reward, log likelihood, and optionally the actions and entropy

class rl4co.models.zoo.mdam.encoder.GraphAttentionEncoder(num_heads, embed_dim, num_layers, node_dim=None, normalization='batch', feed_forward_hidden=512, sdpa_fn=None)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

change(attn, V, h_old, mask, is_tsp=False)[source]
forward(x, mask=None, return_transform_loss=False)[source]
Returns:

  • h [batch_size, graph_size, embed_dim]

  • attn [num_head, batch_size, graph_size, graph_size]

  • V [num_head, batch_size, graph_size, key_dim]

  • h_old [batch_size, graph_size, embed_dim]

class rl4co.models.zoo.mdam.decoder.Decoder(env_name, embedding_dim, num_heads, num_paths=5, mask_inner=True, mask_logits=True, eg_step_gap=200, tanh_clipping=10.0, force_flash_attn=False, shrink_size=None, train_decode_type='sampling', val_decode_type='greedy', test_decode_type='greedy')[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td, encoded_inputs, env, attn, V, h_old, **decoder_kwargs)[source]

class rl4co.models.zoo.mdam.decoder.PrecomputedCache(node_embeddings, graph_context, glimpse_key, glimpse_val, logit_key)[source]

Bases: object

glimpse_key: Tensor
glimpse_val: Tensor
graph_context: Tensor
logit_key: Tensor
node_embeddings: Tensor

POMO

class rl4co.models.zoo.pomo.model.POMO(env, policy=None, policy_kwargs={}, baseline='shared', num_augment=8, use_dihedral_8=True, num_starts=None, select_start_nodes_fn=<function select_start_nodes>, **kwargs)[source]

Bases: REINFORCE

POMO Model for neural combinatorial optimization based on REINFORCE. Based on Kwon et al. (2020): http://arxiv.org/abs/2010.16011.

Parameters:
  • env (RL4COEnvBase) – TorchRL Environment

  • policy (Union[Module, POMOPolicy]) – Policy to use for the algorithm

  • policy_kwargs – Keyword arguments for policy

  • baseline (str) – Baseline to use for the algorithm. Note that POMO only supports shared baseline, so we will throw an error if anything else is passed.

  • num_augment (int) – Number of augmentations (used only for validation and test)

  • use_dihedral_8 (bool) – Whether to use dihedral 8 augmentation

  • num_starts (int) – Number of starts for multi-start. If None, use the number of available actions

  • select_start_nodes_fn (callable) – Function to select the start nodes for the environment defaulting to select_start_nodes()

  • **kwargs – Keyword arguments passed to the superclass
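
With multiple starts per instance and augmentations at validation and test time, each instance yields a grid of rewards. The sketch below shows one plausible way to reduce that grid, keeping the best reward per augmentation and the best reward overall. The `[batch][augment][start]` layout and the function name are assumptions for illustration, not the library's internals.

```python
def best_rewards(rewards):
    """rewards: nested list [batch][augment][start] of scalar rewards (higher is better)."""
    # Best start for every augmented copy of every instance.
    best_per_augment = [[max(starts) for starts in augments] for augments in rewards]
    # Best solution found for each instance across all augmentations and starts.
    best_overall = [max(per_augment) for per_augment in best_per_augment]
    return best_per_augment, best_overall


# One instance, two augmentations, two start nodes each.
rewards = [[[-3.0, -2.5], [-2.0, -4.0]]]
best_per_augment, best_overall = best_rewards(rewards)
print(best_per_augment, best_overall)
```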

shared_step(batch, batch_idx, phase, dataloader_idx=None)[source]

Shared step between train/val/test. To be implemented in subclass

class rl4co.models.zoo.pomo.policy.POMOPolicy(env_name, embedding_dim=128, num_encoder_layers=6, num_heads=8, normalization='instance', use_graph_context=False, select_start_nodes_fn=<function select_start_nodes>, **kwargs)[source]

Bases: AutoregressivePolicy

POMO model policy based on Kwon et al. (2020) http://arxiv.org/abs/2010.16011. We re-declare the most important arguments here for convenience as in the paper. See AutoregressivePolicy superclass for more details.

Note

Although the policy is the base AutoregressivePolicy, we use the default values used in the paper. Differences from the base class:

  • num_encoder_layers=6 (instead of 3)

  • normalization='instance' (instead of 'batch')

  • use_graph_context=False (instead of True)

The last difference follows the paper, which does not use the graph context in the policy; the graph context seems to contribute to overfitting to the training graph size.

Parameters:
  • env_name (str) – Name of the environment used to initialize embeddings

  • embedding_dim (int) – Dimension of the node embeddings

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the attention layers

  • normalization (str) – Normalization type in the attention layers

  • select_start_nodes_fn (callable) – Function to select the start nodes for the environment defaulting to select_start_nodes()

  • **kwargs – keyword arguments passed to the AutoregressivePolicy

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Pointer Network (PtrNet)

class rl4co.models.zoo.ptrnet.model.PointerNetwork(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)[source]

Bases: REINFORCE

Pointer Network for neural combinatorial optimization based on REINFORCE. Based on Vinyals et al. (2015): https://arxiv.org/abs/1506.03134. Refactored from reference implementation: https://github.com/wouterkool/attention-learn-to-route

Parameters:
  • env (RL4COEnvBase) – Environment to use for the algorithm

  • policy (PointerNetworkPolicy) – Policy to use for the algorithm

  • baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)

  • policy_kwargs – Keyword arguments for policy

  • baseline_kwargs – Keyword arguments for baseline

  • **kwargs – Keyword arguments passed to the superclass

class rl4co.models.zoo.ptrnet.policy.PointerNetworkPolicy(env_name, embedding_dim=128, hidden_dim=128, tanh_clipping=10.0, mask_inner=True, mask_logits=True, **kwargs)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td, env, phase='train', decode_type='sampling', eval_tours=None, **unused_kwargs)[source]

class rl4co.models.zoo.ptrnet.encoder.Encoder(input_dim, hidden_dim)[source]

Bases: Module

Maps a graph represented as an input sequence to a hidden vector

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, hidden)[source]

init_hidden(hidden_dim)[source]

Trainable initial hidden state

class rl4co.models.zoo.ptrnet.decoder.Decoder(embedding_dim=128, hidden_dim=128, tanh_exploration=10.0, use_tanh=True, num_glimpses=1, mask_glimpses=True, mask_logits=True)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

calc_logits(x, h_in, logit_mask, context, mask_glimpses=None, mask_logits=None)[source]
forward(decoder_input, embedded_inputs, hidden, context, decode_type='sampling', eval_tours=None)[source]
Parameters:
  • decoder_input – The initial input to the decoder, of size [batch_size x embedding_dim]. Trainable parameter.

  • embedded_inputs – [sourceL x batch_size x embedding_dim]

  • hidden – the prev hidden state, size is [batch_size x hidden_dim]. Initially this is set to (enc_h[-1], enc_c[-1])

  • context – encoder outputs, [sourceL x batch_size x hidden_dim]

recurrence(x, h_in, prev_mask, prev_idxs, step, context)[source]
update_mask(mask, selected)[source]
class rl4co.models.zoo.ptrnet.decoder.SimpleAttention(dim, use_tanh=False, C=10)[source]

Bases: Module

A generic attention module for a decoder in seq2seq

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(query, ref)[source]
Parameters:
  • query – is the hidden state of the decoder at the current time step. batch x dim

  • ref – the set of hidden states from the encoder. sourceL x batch x hidden_dim
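
The `use_tanh` and `C` arguments suggest the logit clipping commonly used in pointer-style decoders, where raw attention scores are squashed into the range [-C, C]. The snippet below sketches that clipping in plain Python. It assumes this is what the two arguments control, and the function name is hypothetical.

```python
import math


def clip_logits(scores, use_tanh=False, C=10.0):
    """Optionally squash raw attention scores into [-C, C] with a scaled tanh."""
    if not use_tanh:
        return list(scores)
    return [C * math.tanh(s) for s in scores]


raw = [100.0, -100.0, 0.0, 0.5]
clipped = clip_logits(raw, use_tanh=True, C=10.0)
print(clipped)
```

Bounding the logits limits how peaked the resulting softmax can become, which keeps some exploration during sampling.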

class rl4co.models.zoo.ptrnet.critic.CriticNetworkLSTM(embedding_dim, hidden_dim, n_process_block_iters, tanh_exploration, use_tanh)[source]

Bases: Module

Useful as a baseline in REINFORCE updates

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(inputs)[source]
Parameters:

inputs – [embedding_dim x batch_size x sourceL] of embedded inputs

SymNCO

class rl4co.models.zoo.symnco.model.SymNCO(env, policy=None, policy_kwargs={}, baseline='symnco', num_augment=4, alpha=0.2, beta=1, num_starts=0, **kwargs)[source]

Bases: REINFORCE

SymNCO Model based on REINFORCE with shared baselines. Based on Kim et al. (2022) https://arxiv.org/abs/2205.13209.

Parameters:
  • env (RL4COEnvBase) – TorchRL environment to use for the algorithm

  • policy (Union[Module, SymNCOPolicy]) – Policy to use for the algorithm

  • policy_kwargs (dict) – Keyword arguments for policy

  • num_augment (int) – Number of augmentations

  • alpha (float) – weight for invariance loss

  • beta (float) – weight for solution symmetricity loss

  • num_starts (int) – Number of starts for multi-start. If None, use the number of available actions

  • **kwargs – Keyword arguments passed to the superclass

shared_step(batch, batch_idx, phase, dataloader_idx=None)[source]

Shared step between train/val/test. To be implemented in subclass

class rl4co.models.zoo.symnco.policy.SymNCOPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', projection_head=None, use_projection_head=True, **kwargs)[source]

Bases: AutoregressivePolicy

SymNCO Policy based on AutoregressivePolicy. This differs from the default AutoregressivePolicy in that it projects the initial embeddings to a lower dimension using a projection head and returns it. This is used in the SymNCO algorithm to compute the invariance loss. Based on Kim et al. (2022) https://arxiv.org/abs/2205.13209.

Parameters:
  • env_name (str) – Name of the environment

  • embedding_dim (int) – Dimension of the embedding

  • num_encoder_layers (int) – Number of layers in the encoder

  • num_heads (int) – Number of heads in the encoder

  • normalization (str) – Normalization to use in the encoder

  • projection_head (Module) – Projection head to use

  • use_projection_head (bool) – Whether to use projection head

  • **kwargs – Keyword arguments passed to the superclass

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td, env=None, phase='train', return_actions=False, return_entropy=False, return_init_embeds=True, **decoder_kwargs)[source]

Forward pass of the policy.

Parameters:
Returns:

Dictionary containing the reward, log likelihood, and optionally the actions and entropy

Return type:

out

rl4co.models.zoo.symnco.losses.invariance_loss(proj_embed, num_augment)[source]

Loss for invariant representation on projected nodes. Corresponds to L_inv in the SymNCO paper.

rl4co.models.zoo.symnco.losses.problem_symmetricity_loss(reward, log_likelihood, dim=1)[source]

REINFORCE loss for problem symmetricity. Baseline is the average reward for all augmented problems. Corresponds to L_ps in the SymNCO paper.

rl4co.models.zoo.symnco.losses.solution_symmetricity_loss(reward, log_likelihood, dim=-1)[source]

REINFORCE loss for solution symmetricity. Baseline is the average reward for all start nodes. Corresponds to L_ss in the SymNCO paper.
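
Both symmetricity losses above are REINFORCE losses whose baseline is an average reward over a group, either augmented problems or start nodes. A minimal plain-Python sketch of that pattern follows. The `[batch][group]` layout and the function name are assumptions, and the library functions operate on tensors along `dim` instead.

```python
def mean_baseline_reinforce_loss(reward, log_likelihood):
    """REINFORCE loss where each row's baseline is the mean reward of that row.

    reward, log_likelihood: nested lists of shape [batch][group], where the group
    axis holds either augmented problems or start nodes of the same instance.
    """
    terms = []
    for r_row, ll_row in zip(reward, log_likelihood):
        baseline = sum(r_row) / len(r_row)
        # Negative sign: minimizing the loss raises the likelihood of above-average solutions.
        terms.extend(-(r - baseline) * ll for r, ll in zip(r_row, ll_row))
    return sum(terms) / len(terms)


loss = mean_baseline_reinforce_loss(
    reward=[[1.0, 3.0]],
    log_likelihood=[[-0.5, -0.25]],
)
print(loss)
```

When every member of a group has the same reward, all advantages are zero and the loss vanishes.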

Search Methods

Active Search (AS)

class rl4co.models.zoo.active_search.search.ActiveSearch(env, policy, dataset, batch_size=1, max_iters=200, augment_size=8, augment_dihedral=True, num_parallel_runs=1, max_runtime=86400, save_path=None, optimizer='Adam', optimizer_kwargs={'lr': 0.00026, 'weight_decay': 1e-06}, **kwargs)[source]

Bases: SearchBase

Active Search for Neural Combinatorial Optimization from Bello et al. (2016). Fine-tunes the whole policy network (encoder + decoder) on a batch of instances. Reference: https://arxiv.org/abs/1611.09940

Parameters:
  • env – RL4CO environment to be solved

  • policy – policy network

  • dataset (Union[Dataset, str]) – dataset to be used for training

  • batch_size (int) – batch size for training

  • max_iters (int) – maximum number of iterations

  • augment_size (int) – number of augmentations per state

  • augment_dihedral (bool) – whether to augment with dihedral rotations

  • num_parallel_runs (int) – number of parallel runs

  • max_runtime (int) – maximum runtime in seconds

  • save_path (str) – path to save solution checkpoints

  • optimizer (Union[str, Optimizer, partial]) – optimizer to use for training

  • optimizer_kwargs (dict) – keyword arguments for optimizer

  • **kwargs – additional keyword arguments
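
The `augment_size=8` default together with `augment_dihedral` suggests the eight symmetries of the unit square (rotations and reflections). The sketch below generates those eight views for a list of 2D coordinates. It assumes coordinates lie in [0, 1] x [0, 1]; the function name is hypothetical, and the ordering of the views is arbitrary.

```python
import math


def dihedral_8(points):
    """Return the 8 dihedral views (rotations/reflections of the unit square) of 2D points."""
    transforms = [
        lambda x, y: (x, y),          # identity
        lambda x, y: (1 - x, y),      # flip x
        lambda x, y: (x, 1 - y),      # flip y
        lambda x, y: (1 - x, 1 - y),  # rotate by 180 degrees
        lambda x, y: (y, x),          # swap axes
        lambda x, y: (1 - y, x),
        lambda x, y: (y, 1 - x),
        lambda x, y: (1 - y, 1 - x),
    ]
    return [[t(x, y) for x, y in points] for t in transforms]


views = dihedral_8([(0.1, 0.3), (0.6, 0.2)])
# Distances between nodes are preserved, so tour lengths are identical in every view.
lengths = [math.dist(view[0], view[1]) for view in views]
print(len(views), lengths[0])
```

Because the transformed instances have the same optimal tour, the policy can be evaluated on each view and the best solution kept.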

on_train_batch_end(outputs, batch, batch_idx)[source]

We store the best solution and reward found.

Return type:

None

on_train_batch_start(batch, batch_idx)[source]

Called before training (i.e. search) for a new batch begins. We re-load the original policy state dict and configure the optimizer.

on_train_epoch_end()[source]

Called when the training ends. If the epoch ends, it means we have finished searching over the instances, thus the trainer should stop.

Return type:

None

setup(stage='fit')[source]

Set up the base class and instantiate:

  • augmentation

  • instance solutions and rewards

  • original policy state dict

training_step(batch, batch_idx)[source]

Main search loop. We use the training step to effectively adapt to a batch of instances.

Efficient Active Search (EAS)

class rl4co.models.zoo.eas.search.EAS(env, policy, dataset, use_eas_embedding=True, use_eas_layer=False, eas_emb_cache_keys=['logit_key'], eas_lambda=0.013, batch_size=2, max_iters=200, augment_size=8, augment_dihedral=True, num_parallel_runs=1, baseline='multistart', max_runtime=86400, save_path=None, optimizer='Adam', optimizer_kwargs={'lr': 0.0041, 'weight_decay': 1e-06}, verbose=True, **kwargs)[source]

Bases: SearchBase

Efficient Active Search for Neural Combinatorial Optimization from Hottung et al. (2022). Fine-tunes a subset of parameters (such as node embeddings or newly added layers), thus avoiding expensive re-encoding of the problem. Reference: https://openreview.net/pdf?id=nO5caZwFwYu

Parameters:
  • env – RL4CO environment to be solved

  • policy – policy network

  • dataset (Union[Dataset, str]) – dataset to be used for training

  • use_eas_embedding (bool) – whether to use EAS embedding (EASEmb)

  • use_eas_layer (bool) – whether to use EAS layer (EASLay)

  • eas_emb_cache_keys (List[str]) – keys to cache in the embedding

  • eas_lambda (float) – lambda parameter for IL loss

  • batch_size (int) – batch size for training

  • max_iters (int) – maximum number of iterations

  • augment_size (int) – number of augmentations per state

  • augment_dihedral (bool) – whether to augment with dihedral rotations

  • num_parallel_runs (int) – number of parallel runs

  • baseline (str) – REINFORCE baseline type (multistart, symmetric, full)

  • max_runtime (int) – maximum runtime in seconds

  • save_path (str) – path to save solution checkpoints

  • optimizer (Union[str, Optimizer, partial]) – optimizer to use for training

  • optimizer_kwargs (dict) – keyword arguments for optimizer

  • verbose (bool) – whether to print progress for each iteration
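
The `eas_lambda` parameter weights an imitation learning (IL) term against the reinforcement learning objective. A minimal sketch of such a weighted sum is shown below. It assumes the IL loss is the negative log-likelihood of the best solution found so far, and the function name is hypothetical.

```python
def eas_total_loss(rl_loss, best_solution_log_likelihood, eas_lambda=0.013):
    """Combine an RL loss with an imitation term on the incumbent best solution."""
    # Imitation term: make the best solution found so far more likely under the policy.
    il_loss = -best_solution_log_likelihood
    return rl_loss + eas_lambda * il_loss


total = eas_total_loss(rl_loss=1.0, best_solution_log_likelihood=-2.0, eas_lambda=0.5)
print(total)
```

A larger `eas_lambda` pulls the search more strongly toward the incumbent solution, while a smaller value leaves more room for exploration.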

on_train_batch_end(outputs, batch, batch_idx)[source]

We store the best solution and reward found.

Return type:

None

on_train_batch_start(batch, batch_idx)[source]

Called before training (i.e. search) for a new batch begins. We re-load the original policy state dict and configure all parameters not to require gradients. We do the rest in the training step.

on_train_epoch_end()[source]

Called when training ends.

Return type:

None

setup(stage='fit')[source]

Set up the base class and instantiate:

  • augmentation

  • instance solutions and rewards

  • original policy state dict

training_step(batch, batch_idx)[source]

Main search loop. We use the training step to effectively adapt to a batch of instances.

class rl4co.models.zoo.eas.search.EASEmb(*args, **kwargs)[source]

Bases: EAS

EAS with embedding adaptation

class rl4co.models.zoo.eas.search.EASLay(*args, **kwargs)[source]

Bases: EAS

EAS with layer adaptation

rl4co.models.zoo.eas.decoder.forward_eas(self, td, cached_embeds, best_solutions, iter_count=0, env=None, decode_type='sampling_multistart', num_starts=None, softmax_temp=None, **unused_kwargs)[source]

Forward pass of the decoder. Given the environment state and the pre-computed embeddings, compute the logits and sample actions.

Parameters:
  • td (TensorDict) – Input TensorDict containing the environment state

  • embeddings – Precomputed embeddings for the nodes. Can be already precomputed cached in form of q, k, v and

  • env (Union[str, RL4COEnvBase]) – Environment to use for decoding. If None, the environment is instantiated from env_name. Note that it is more efficient to pass an already instantiated environment each time for fine-grained control

  • decode_type (str) – Type of decoding to use. Can be one of:

      • "sampling": sample from the logits

      • "greedy": take the argmax of the logits

      • "multistart_sampling": sample as sampling, but with multi-start decoding

      • "multistart_greedy": sample as greedy, but with multi-start decoding

  • num_starts (int) – Number of multi-starts to use. If None, will be calculated from the action mask

  • softmax_temp (float) – Temperature for the softmax. If None, default softmax is used from the LogitAttention module

  • calc_reward – Whether to calculate the reward for the decoded sequence

rl4co.models.zoo.eas.decoder.forward_logit_attn_eas_lay(self, query, key, value, logit_key, mask, softmax_temp=None)[source]

Add layer to the forward pass of logit attention, i.e. Single-head attention.

class rl4co.models.zoo.eas.nn.EASLayerNet(num_instances, emb_dim)[source]

Bases: Module

Instantiate weights and biases for the added layer. The layer is defined as: h = relu(emb * W1 + b1); out = h * W2 + b2. Wrapping in nn.Parameter makes the parameters trainable and sets gradient to True.

Parameters:
  • num_instances (int) – Number of instances in the dataset

  • emb_dim (int) – Dimension of the embedding
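
The layer formula in the description, h = relu(emb * W1 + b1) followed by out = h * W2 + b2, can be traced by hand with a small plain-Python sketch for a single embedding vector. The helper names and the toy weights are illustrative assumptions. The actual class holds per-instance parameters as tensors.

```python
def matvec(vec, weight, bias):
    """Compute vec @ weight + bias for vec: [in_dim], weight: [in_dim][out_dim], bias: [out_dim]."""
    return [
        sum(v * weight[i][j] for i, v in enumerate(vec)) + bias[j]
        for j in range(len(bias))
    ]


def eas_layer(emb, w1, b1, w2, b2):
    """Two-layer transform: h = relu(emb * W1 + b1); out = h * W2 + b2."""
    hidden = [max(0.0, x) for x in matvec(emb, w1, b1)]
    return matvec(hidden, w2, b2)


identity = [[1.0, 0.0], [0.0, 1.0]]
out = eas_layer(
    emb=[1.0, -2.0],
    w1=identity, b1=[0.0, 0.0],
    w2=identity, b2=[0.5, 0.5],
)
print(out)  # the negative component is zeroed by the ReLU before the second layer
```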

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(*args)[source]

emb: [num_instances, group_num, emb_dim]