Model Zoo
Models from the literature and contributions are contained in the Model Zoo.
Auto-Regressive Models
Attention Model (AM)
- class rl4co.models.zoo.am.model.AttentionModel(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)
Bases: REINFORCE
Attention Model based on REINFORCE: https://arxiv.org/abs/1803.08475.
- Parameters:
env (RL4COEnvBase) – Environment to use for the algorithm
policy (AttentionModelPolicy) – Policy to use for the algorithm
baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)
policy_kwargs – Keyword arguments for policy
baseline_kwargs – Keyword arguments for baseline
**kwargs – Keyword arguments passed to the superclass
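The default baseline schedule can be illustrated with a small pure-Python sketch. It is not the RL4CO implementation: the helper names and the decay `beta=0.8` are assumptions chosen for illustration. It shows an exponential moving-average baseline (as used during the warm-up epoch) and the REINFORCE loss built from the advantage `reward - baseline`; with the rollout baseline, `baseline_values` would instead hold the rewards of a greedy rollout of a frozen copy of the policy.

```python
def exponential_baseline(rewards, prev_value=None, beta=0.8):
    """Hypothetical helper: exponential moving average of the mean batch reward.

    beta=0.8 is an assumed decay; the first call simply returns the batch mean.
    """
    batch_mean = sum(rewards) / len(rewards)
    if prev_value is None:
        return batch_mean
    return beta * prev_value + (1.0 - beta) * batch_mean


def reinforce_loss(rewards, log_likelihoods, baseline_values):
    """REINFORCE loss: -mean((R - b) * log p(solution))."""
    advantages = [r - b for r, b in zip(rewards, baseline_values)]
    weighted = [a * ll for a, ll in zip(advantages, log_likelihoods)]
    return -sum(weighted) / len(weighted)


# Rewards are negative tour lengths, so higher is better.
rewards = [-5.0, -4.0]
log_likelihoods = [-1.0, -2.0]
b = exponential_baseline(rewards)                        # -4.5
loss = reinforce_loss(rewards, log_likelihoods, [b, b])  # 0.25
```

In a real training loop the advantage would be detached from the computation graph so that only the log-likelihoods receive gradients.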
- class rl4co.models.zoo.am.policy.AttentionModelPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)
Bases: AutoregressivePolicy
Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475. We re-declare the most important arguments here for convenience as in the paper. See the AutoregressivePolicy superclass for more details.
- Parameters:
env_name (str) – Name of the environment used to initialize embeddings
embedding_dim (int) – Dimension of the node embeddings
num_encoder_layers (int) – Number of layers in the encoder
num_heads (int) – Number of heads in the attention layers
normalization (str) – Normalization type in the attention layers
**kwargs – keyword arguments passed to the AutoregressivePolicy
Attention Model (AM-PPO)
- class rl4co.models.zoo.ppo.model.PPOModel(env, policy=None, critic=None, policy_kwargs={}, critic_kwargs={}, **kwargs)
Bases: PPO
PPO Model based on Proximal Policy Optimization (PPO).
- Parameters:
env (RL4COEnvBase) – Environment to use for the algorithm
critic (CriticNetwork) – Critic to use for the algorithm
policy_kwargs (dict) – Keyword arguments for policy
critic_kwargs (dict) – Keyword arguments for critic
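For reference, the clipped surrogate objective that PPO optimizes (Schulman et al., 2017) can be sketched in pure Python. The function name and `clip_range=0.2` (a common default) are assumptions for illustration and not necessarily what PPOModel uses; in practice the advantages would be estimated with the critic.

```python
import math


def ppo_clipped_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Negative clipped surrogate objective, averaged over samples."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # Pessimistic bound: keep the smaller of the unclipped and clipped terms.
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(advantages)


# A probability ratio of 2 with a positive advantage is clipped to 1.2.
loss = ppo_clipped_loss([math.log(2.0)], [0.0], [1.0])
```

The clipping removes the incentive to move the new policy far from the one that collected the samples, which is what allows several optimization epochs per batch.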
- class rl4co.models.zoo.ppo.policy.PPOPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)
Bases: AutoregressivePolicy
PPO Policy. The backbone model is inspired by Kool et al. (2019): https://arxiv.org/abs/1803.08475. This is simply a wrapper around the AutoregressivePolicy class. PPO requires an evaluate_actions method in AutoregressivePolicy to obtain the log probabilities and entropy of given actions under the current policy.
- Parameters:
env_name (str) – Name of the environment used to initialize embeddings
embedding_dim (int) – Dimension of the node embeddings
num_encoder_layers (int) – Number of layers in the encoder
num_heads (int) – Number of heads in the attention layers
normalization (str) – Normalization type in the attention layers
**kwargs – keyword arguments passed to the AutoregressivePolicy
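What an evaluate_actions-style method has to compute can be sketched with plain Python: given the per-step logits of the current policy and the actions taken earlier, it sums the log-probabilities of those actions and averages the per-step entropy. The helper names are hypothetical, and representing masked actions by `float('-inf')` logits is an assumption of this sketch.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax; -inf logits stay at -inf."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def evaluate_actions_sketch(step_logits, actions):
    """Return (sum of log-probs of `actions`, mean per-step entropy)."""
    total_log_prob = 0.0
    total_entropy = 0.0
    for logits, action in zip(step_logits, actions):
        log_probs = log_softmax(logits)
        total_log_prob += log_probs[action]
        # Skip masked entries, since p * log p -> 0 as p -> 0.
        total_entropy -= sum(math.exp(lp) * lp for lp in log_probs if lp != float("-inf"))
    return total_log_prob, total_entropy / len(actions)


# Two decoding steps; in the second one only action 0 is feasible.
log_prob, entropy = evaluate_actions_sketch([[0.0, 0.0], [0.0, float("-inf")]], [0, 0])
```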
Heterogeneous Attention Model (HAM)
- class rl4co.models.zoo.ham.model.HeterogeneousAttentionModel(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)
Bases: REINFORCE
Heterogeneous Attention Model for solving the Pickup and Delivery Problem based on REINFORCE: https://arxiv.org/abs/2110.02634.
- Parameters:
env (RL4COEnvBase) – Environment to use for the algorithm
policy (HeterogeneousAttentionModelPolicy) – Policy to use for the algorithm
baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)
policy_kwargs – Keyword arguments for policy
baseline_kwargs – Keyword arguments for baseline
**kwargs – Keyword arguments passed to the superclass
- class rl4co.models.zoo.ham.policy.HeterogeneousAttentionModelPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)
Bases: AutoregressivePolicy
Heterogeneous Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475. We re-declare the most important arguments here for convenience as in the paper. See the AutoregressivePolicy superclass for more details.
- Parameters:
env_name (str) – Name of the environment used to initialize embeddings
encoder – Encoder to use for the policy
embedding_dim (int) – Dimension of the node embeddings
num_encoder_layers (int) – Number of layers in the encoder
num_heads (int) – Number of heads in the attention layers
normalization (str) – Normalization type in the attention layers
**kwargs – keyword arguments passed to the AutoregressivePolicy
- class rl4co.models.zoo.ham.encoder.GraphHeterogeneousAttentionEncoder(num_heads=8, embedding_dim=128, num_encoder_layers=3, env_name=None, normalization='batch', feed_forward_hidden=512, sdpa_fn=None)
Bases: Module
- forward(x, mask=None)
- class rl4co.models.zoo.ham.encoder.HeterogeneuousMHALayer(num_heads, embed_dim, feed_forward_hidden=512, normalization='batch')
Bases: Sequential
- class rl4co.models.zoo.ham.attention.HeterogenousMHA(num_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None)
Bases: Module
Heterogeneous Multi-Head Attention for Pickup and Delivery problems: https://arxiv.org/abs/2110.02634
- forward(q, h=None, mask=None)
- Parameters:
q – queries (batch_size, n_query, input_dim)
h – data (batch_size, graph_size, input_dim)
mask – mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2-dim if n_query == 1). Mask should contain 1 if attention is not possible.
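The mask convention above (1 means attention is not possible) is typically applied by setting the blocked compatibility scores to -inf before the softmax, so blocked keys receive zero weight. The following single-query, single-head sketch in pure Python illustrates this; the function names are hypothetical, and it does not reproduce the separate pickup/delivery attention groups of HeterogenousMHA.

```python
import math


def compatibility(query, keys):
    """Scaled dot-product scores q . k_j / sqrt(d) for one query."""
    d = len(query)
    return [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]


def masked_attention_weights(scores, mask):
    """Softmax over scores where mask[j] == 1 means key j cannot be attended."""
    logits = [float("-inf") if blocked else s for s, blocked in zip(scores, mask)]
    m = max(logits)
    if m == float("-inf"):
        raise ValueError("every key is masked for this query")
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


scores = compatibility([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])
weights = masked_attention_weights(scores, [0, 1, 0])  # key 1 gets weight 0.0
```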
Matrix Encoding Network (MatNet)
- class rl4co.models.zoo.matnet.model.MatNet(env, policy=None, optimizer_kwargs={'lr': 0.0004, 'weight_decay': 1e-06}, lr_scheduler='MultiStepLR', lr_scheduler_kwargs={'gamma': 0.1, 'milestones': [2001, 2101]}, use_dihedral_8=False, num_starts=None, train_data_size=10000, batch_size=200, policy_params={}, model_params={})
Bases: POMO
- class rl4co.models.zoo.matnet.policy.MatNetPolicy(env_name, embedding_dim=256, num_encoder_layers=5, num_heads=16, normalization='instance', init_embedding_kwargs={'mode': 'RandomOneHot'}, use_graph_context=False, **kwargs)
Bases: AutoregressivePolicy
MatNet Policy from Kwon et al. (2021). Reference: https://arxiv.org/abs/2106.11113
Warning
This implementation is under development and subject to change.
- Parameters:
env_name (str) – Name of the environment used to initialize embeddings
embedding_dim (int) – Dimension of the node embeddings
num_encoder_layers (int) – Number of layers in the encoder
num_heads (int) – Number of heads in the attention layers
normalization (str) – Normalization type in the attention layers
**kwargs – keyword arguments passed to the AutoregressivePolicy
Default parameters are adopted from the original implementation.
- class rl4co.models.zoo.matnet.encoder.MatNetATSPInitEmbedding(embedding_dim, mode='RandomOneHot')
Bases: Module
Prepares the initial row and column embeddings for ATSP.
- forward(td)
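Our reading of the MatNet paper is that the 'RandomOneHot' mode initializes the row embeddings to zero vectors and the column embeddings to one-hot vectors whose hot positions are drawn at random without replacement, which requires `embedding_dim >= num_nodes`. The list-based sketch below illustrates that idea; the function name is hypothetical and this is not the MatNetATSPInitEmbedding code.

```python
import random


def random_one_hot_init(num_nodes, embedding_dim, seed=None):
    """Sketch: zero row embeddings, randomly assigned one-hot column embeddings."""
    if num_nodes > embedding_dim:
        raise ValueError("need embedding_dim >= num_nodes for distinct one-hot vectors")
    rng = random.Random(seed)
    # Distinct hot positions, so every column starts with a unique embedding.
    hot_indices = rng.sample(range(embedding_dim), num_nodes)
    row_emb = [[0.0] * embedding_dim for _ in range(num_nodes)]
    col_emb = []
    for idx in hot_indices:
        vec = [0.0] * embedding_dim
        vec[idx] = 1.0
        col_emb.append(vec)
    return row_emb, col_emb


rows, cols = random_one_hot_init(num_nodes=5, embedding_dim=8, seed=0)
```

Drawing the hot positions at random (rather than using a fixed order) keeps the model from attaching meaning to a particular column index.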
- class rl4co.models.zoo.matnet.encoder.MatNetCrossMHA(embedding_dim, num_heads, bias=True, mixer_hidden_dim=16, mix1_init=0.7071067811865476, mix2_init=0.25)
Bases: Module
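MatNet's cross attention mixes each dot-product attention score with the corresponding cost-matrix entry through a small MLP per head, whose hidden size we take to correspond to mixer_hidden_dim. The single-head, pure-Python sketch below uses hypothetical names. We assume mix1_init and mix2_init are uniform initialization bounds: the defaults equal sqrt(1/2) and sqrt(1/16), i.e. 1/sqrt(fan_in) for a first layer with 2 inputs and a second layer with 16 hidden units.

```python
import math
import random


def init_mixer(hidden_dim=16, mix1_init=math.sqrt(0.5), mix2_init=0.25, seed=0):
    """Assumed initialization: uniform(-bound, bound) for each mixer layer."""
    rng = random.Random(seed)
    w1 = [[rng.uniform(-mix1_init, mix1_init) for _ in range(hidden_dim)] for _ in range(2)]
    b1 = [rng.uniform(-mix1_init, mix1_init) for _ in range(hidden_dim)]
    w2 = [rng.uniform(-mix2_init, mix2_init) for _ in range(hidden_dim)]
    b2 = rng.uniform(-mix2_init, mix2_init)
    return w1, b1, w2, b2


def mixed_score(dot_score, cost, mixer):
    """2 -> hidden -> 1 MLP over (attention score, matrix entry)."""
    w1, b1, w2, b2 = mixer
    hidden = [
        max(0.0, dot_score * w1[0][k] + cost * w1[1][k] + b1[k])  # ReLU
        for k in range(len(b1))
    ]
    return sum(h * w for h, w in zip(hidden, w2)) + b2


mixer = init_mixer()
score = mixed_score(0.3, 1.7, mixer)  # one entry of the mixed score matrix
```

Applying the same MLP to every (row, column) pair lets the cost matrix shape attention without tying the model to a fixed number of nodes.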
- class rl4co.models.zoo.matnet.encoder.MatNetEncoder(embedding_dim=256, num_heads=16, num_layers=5, normalization='instance', feed_forward_hidden=512, init_embedding=None, init_embedding_kwargs=None)
Bases: Module
- forward(td)
- class rl4co.models.zoo.matnet.encoder.MatNetMHA(embedding_dim, num_heads, bias=True)
Bases: Module
- class rl4co.models.zoo.matnet.encoder.MatNetMHALayer(embedding_dim, num_heads, bias=True, feed_forward_hidden=512, normalization='instance')
Bases: Module
- class rl4co.models.zoo.matnet.encoder.MatNetMHANetwork(embedding_dim=128, num_heads=8, num_layers=3, normalization='batch', feed_forward_hidden=512)
Bases: Module
- class rl4co.models.zoo.matnet.decoder.MatNetDecoder(env_name, embedding_dim, num_heads, use_graph_context=True, select_start_nodes_fn=<function select_start_nodes>, linear_bias=False, context_embedding=None, dynamic_embedding=None, **logit_attn_kwargs)
Bases: AutoregressiveDecoder
Multi-Decoder Attention Model (MDAM)
- class rl4co.models.zoo.mdam.model.MDAM(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)
Bases: REINFORCE
Multi-Decoder Attention Model (MDAM) trains multiple diverse policies, which effectively increases the chance of finding good solutions compared with existing methods that train only one policy. Reference: https://arxiv.org/abs/2012.10638; implementation reference: https://github.com/liangxinedu/MDAM.
- Parameters:
env (RL4COEnvBase) – Environment to use for the algorithm
policy (MDAMPolicy) – Policy to use for the algorithm
baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)
policy_kwargs – Keyword arguments for policy
baseline_kwargs – Keyword arguments for baseline
**kwargs – Keyword arguments passed to the superclass
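MDAM pushes its decoders apart by rewarding divergence between their action distributions; the paper uses a Kullback-Leibler term for this. The sketch below computes the mean pairwise KL divergence between the decoders' distributions at a single decoding step. The names, the `eps` smoothing, and how the term would be weighted and subtracted in the training loss are assumptions for illustration.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions; zero-probability entries of p are skipped."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def mean_pairwise_kl(decoder_probs):
    """Average KL over ordered pairs of decoders' action distributions at one step."""
    n = len(decoder_probs)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                total += kl_divergence(decoder_probs[i], decoder_probs[j])
    return total / (n * (n - 1))


diversity = mean_pairwise_kl([[0.9, 0.1], [0.1, 0.9], [0.5, 0.5]])
```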
- class rl4co.models.zoo.mdam.policy.MDAMPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', **kwargs)
Bases: AutoregressivePolicy
Multi-Decoder Attention Model (MDAM) policy.
- forward(td, env=None, phase='train', return_actions=False, **decoder_kwargs)
Forward pass of the policy.
- Parameters:
td (TensorDict) – TensorDict containing the environment state
env (Union[str, RL4COEnvBase]) – Environment to use for decoding
phase (str) – Phase of the algorithm (train, val, test)
return_actions (bool) – Whether to return the actions
return_entropy – Whether to return the entropy
decoder_kwargs – Keyword arguments for the decoder. See rl4co.models.zoo.common.autoregressive.decoder.AutoregressiveDecoder
- Returns:
out: dictionary containing the reward, log likelihood, and optionally the actions and entropy
- class rl4co.models.zoo.mdam.encoder.GraphAttentionEncoder(num_heads, embed_dim, num_layers, node_dim=None, normalization='batch', feed_forward_hidden=512, sdpa_fn=None)
Bases: Module
- class rl4co.models.zoo.mdam.decoder.Decoder(env_name, embedding_dim, num_heads, num_paths=5, mask_inner=True, mask_logits=True, eg_step_gap=200, tanh_clipping=10.0, force_flash_attn=False, shrink_size=None, train_decode_type='sampling', val_decode_type='greedy', test_decode_type='greedy')
Bases: Module
- forward(td, encoded_inputs, env, attn, V, h_old, **decoder_kwargs)
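The tanh_clipping and mask_logits arguments, together with the sampling/greedy decode types, can be sketched for one decoding step: logits are squashed into (-C, C) with C * tanh(logit), infeasible actions are masked out, and an action is chosen greedily or by sampling. This pure-Python illustration uses hypothetical names; here mask[i] is True when action i is feasible, which is an assumption of the sketch.

```python
import math
import random


def clipped_log_probs(logits, mask, tanh_clipping=10.0):
    """C * tanh(logit), then mask infeasible actions, then log-softmax."""
    clipped = [tanh_clipping * math.tanh(x) for x in logits]
    masked = [x if feasible else float("-inf") for x, feasible in zip(clipped, mask)]
    m = max(masked)
    log_z = m + math.log(sum(math.exp(x - m) for x in masked))
    return [x - log_z for x in masked]


def select_action(log_probs, decode_type="greedy", rng=None):
    """Greedy: argmax. Sampling: draw from the categorical distribution."""
    if decode_type == "greedy":
        return max(range(len(log_probs)), key=lambda i: log_probs[i])
    if decode_type == "sampling":
        rng = rng or random.Random()
        probs = [math.exp(lp) for lp in log_probs]  # masked actions get weight 0.0
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    raise ValueError(f"unknown decode_type: {decode_type}")


log_probs = clipped_log_probs([100.0, 0.0, -100.0], mask=[False, True, True])
action = select_action(log_probs, "greedy")
```

Clipping bounds how peaked the distribution can become, which keeps some exploration alive during sampling-based training.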
POMO
- class rl4co.models.zoo.pomo.model.POMO(env, policy=None, policy_kwargs={}, baseline='shared', num_augment=8, use_dihedral_8=True, num_starts=None, select_start_nodes_fn=<function select_start_nodes>, **kwargs)
Bases: REINFORCE
POMO Model for neural combinatorial optimization based on REINFORCE. Based on Kwon et al. (2020): http://arxiv.org/abs/2010.16011.
- Parameters:
env (RL4COEnvBase) – TorchRL Environment
policy (Union[Module, POMOPolicy]) – Policy to use for the algorithm
policy_kwargs – Keyword arguments for policy
baseline (str) – Baseline to use for the algorithm. Note that POMO only supports the shared baseline, so an error is raised if any other baseline is passed.
num_augment (int) – Number of augmentations (used only for validation and test)
use_dihedral_8 (bool) – Whether to use dihedral 8 augmentation
num_starts (int) – Number of starts for multi-start. If None, use the number of available actions
select_start_nodes_fn (callable) – Function to select the start nodes for the environment, defaulting to select_start_nodes()
**kwargs – Keyword arguments passed to the superclass
Shared step between train/val/test. To be implemented in subclass
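POMO's shared baseline can be sketched as follows: each instance is decoded from several start nodes, and the baseline for every trajectory is the mean reward over the trajectories of the same instance, so the advantages of one instance sum to zero. At inference time the best trajectory among the starts is kept. The function names are hypothetical.

```python
def pomo_shared_advantages(rewards_per_start):
    """rewards_per_start[b][s]: reward of the trajectory of instance b from start s."""
    advantages = []
    for rewards in rewards_per_start:
        baseline = sum(rewards) / len(rewards)  # shared across the starts of one instance
        advantages.append([r - baseline for r in rewards])
    return advantages


def best_of_starts(rewards_per_start):
    """Inference: keep the best trajectory among the starts of each instance."""
    return [max(rewards) for rewards in rewards_per_start]


rewards = [[-3.0, -5.0, -4.0], [-2.0, -2.0, -2.0]]
adv = pomo_shared_advantages(rewards)  # [[1.0, -1.0, 0.0], [0.0, 0.0, 0.0]]
best = best_of_starts(rewards)         # [-3.0, -2.0]
```

Because the baseline comes from the same batch of rollouts, no separate critic or frozen baseline policy is needed.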
- class rl4co.models.zoo.pomo.policy.POMOPolicy(env_name, embedding_dim=128, num_encoder_layers=6, num_heads=8, normalization='instance', use_graph_context=False, select_start_nodes_fn=<function select_start_nodes>, **kwargs)
Bases: AutoregressivePolicy
POMO model policy based on Kwon et al. (2020): http://arxiv.org/abs/2010.16011. We re-declare the most important arguments here for convenience as in the paper. See the AutoregressivePolicy superclass for more details.
Note
Although the policy is the base AutoregressivePolicy, we use the default values from the paper. Differently from the base class:
- num_encoder_layers=6 (instead of 3)
- normalization="instance" (instead of "batch")
- use_graph_context=False (instead of True)
The latter follows the paper, which does not use the graph context in the policy; this seems to be helpful in overfitting to the training graph size.
- Parameters:
env_name (str) – Name of the environment used to initialize embeddings
embedding_dim (int) – Dimension of the node embeddings
num_encoder_layers (int) – Number of layers in the encoder
num_heads (int) – Number of heads in the attention layers
normalization (str) – Normalization type in the attention layers
select_start_nodes_fn (callable) – Function to select the start nodes for the environment, defaulting to select_start_nodes()
**kwargs – keyword arguments passed to the AutoregressivePolicy
Pointer Network (PtrNet)
- class rl4co.models.zoo.ptrnet.model.PointerNetwork(env, policy=None, baseline='rollout', policy_kwargs={}, baseline_kwargs={}, **kwargs)
Bases: REINFORCE
Pointer Network for neural combinatorial optimization based on REINFORCE. Based on Vinyals et al. (2015): https://arxiv.org/abs/1506.03134. Refactored from the reference implementation: https://github.com/wouterkool/attention-learn-to-route
- Parameters:
env (RL4COEnvBase) – Environment to use for the algorithm
policy (PointerNetworkPolicy) – Policy to use for the algorithm
baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline)
policy_kwargs – Keyword arguments for policy
baseline_kwargs – Keyword arguments for baseline
**kwargs – Keyword arguments passed to the superclass
- class rl4co.models.zoo.ptrnet.policy.PointerNetworkPolicy(env_name, embedding_dim=128, hidden_dim=128, tanh_clipping=10.0, mask_inner=True, mask_logits=True, **kwargs)
Bases: Module
- forward(td, env, phase='train', decode_type='sampling', eval_tours=None, **unused_kwargs)
- class rl4co.models.zoo.ptrnet.encoder.Encoder(input_dim, hidden_dim)
Bases: Module
Maps a graph represented as an input sequence to a hidden vector.
- forward(x, hidden)
Trainable initial hidden state
- class rl4co.models.zoo.ptrnet.decoder.Decoder(embedding_dim=128, hidden_dim=128, tanh_exploration=10.0, use_tanh=True, num_glimpses=1, mask_glimpses=True, mask_logits=True)
Bases: Module
- forward(decoder_input, embedded_inputs, hidden, context, decode_type='sampling', eval_tours=None)
- Parameters:
decoder_input – The initial input to the decoder, of size [batch_size x embedding_dim]. Trainable parameter.
embedded_inputs – [sourceL x batch_size x embedding_dim]
hidden – the previous hidden state, of size [batch_size x hidden_dim]. Initially this is set to (enc_h[-1], enc_c[-1])
context – encoder outputs, [sourceL x batch_size x hidden_dim]
- class rl4co.models.zoo.ptrnet.decoder.SimpleAttention(dim, use_tanh=False, C=10)
Bases: Module
A generic attention module for a decoder in seq2seq.
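In the Pointer Network, the decoder "points" at input elements with additive attention, u_j = v^T tanh(W_q q + W_ref r_j); when use_tanh=True the scores are additionally clipped to C * tanh(u_j). A pure-Python sketch under those assumptions follows; the matrix names are hypothetical and SimpleAttention's exact parameterization may differ.

```python
import math


def matvec(matrix, vec):
    """Dense matrix-vector product with nested lists."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def pointer_scores(query, refs, w_q, w_ref, v, use_tanh=False, C=10.0):
    """u_j = v . tanh(W_q q + W_ref r_j); optionally clipped to C * tanh(u_j)."""
    q_proj = matvec(w_q, query)
    scores = []
    for ref in refs:
        r_proj = matvec(w_ref, ref)
        hidden = [math.tanh(a + b) for a, b in zip(q_proj, r_proj)]
        u = sum(vi * hi for vi, hi in zip(v, hidden))
        scores.append(C * math.tanh(u) if use_tanh else u)
    return scores


eye = [[1.0, 0.0], [0.0, 1.0]]
scores = pointer_scores([0.0, 0.0], [[0.0, 0.0], [1.0, 0.0]], eye, eye, [1.0, 1.0])
```

The scores are then masked and passed through a softmax to give the pointing distribution over input positions.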
SymNCO
- class rl4co.models.zoo.symnco.model.SymNCO(env, policy=None, policy_kwargs={}, baseline='symnco', num_augment=4, alpha=0.2, beta=1, num_starts=0, **kwargs)
Bases: REINFORCE
SymNCO Model based on REINFORCE with shared baselines. Based on Kim et al. (2022): https://arxiv.org/abs/2205.13209.
- Parameters:
env (RL4COEnvBase) – TorchRL environment to use for the algorithm
policy (Union[Module, SymNCOPolicy]) – Policy to use for the algorithm
policy_kwargs (dict) – Keyword arguments for policy
num_augment (int) – Number of augmentations
alpha (float) – weight for invariance loss
beta (float) – weight for solution symmetricity loss
num_starts (int) – Number of starts for multi-start. If None, use the number of available actions
**kwargs – Keyword arguments passed to the superclass
Shared step between train/val/test. To be implemented in subclass
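How alpha and beta weight the SymNCO objective can be sketched for a single instance whose rewards form a grid of augmentations by starts. For illustration we assume the problem-symmetricity term is a REINFORCE loss whose shared baseline is the mean over augmentations, and the solution-symmetricity term uses the mean over starts; this grouping and all names are assumptions, not the RL4CO code.

```python
def reinforce_shared(rewards, log_likelihoods):
    """REINFORCE loss with a baseline equal to the mean reward of the group."""
    baseline = sum(rewards) / len(rewards)
    return -sum((r - baseline) * ll for r, ll in zip(rewards, log_likelihoods)) / len(rewards)


def symnco_loss_sketch(rewards, log_likelihoods, invariance_term, alpha=0.2, beta=1.0):
    """rewards[k][s], log_likelihoods[k][s]: augmentation k, start s (one instance)."""
    num_aug, num_starts = len(rewards), len(rewards[0])
    # Assumed problem-symmetricity term: group the same start across augmentations.
    loss_ps = sum(
        reinforce_shared(
            [rewards[k][s] for k in range(num_aug)],
            [log_likelihoods[k][s] for k in range(num_aug)],
        )
        for s in range(num_starts)
    ) / num_starts
    # Assumed solution-symmetricity term: group the starts of one augmentation.
    loss_ss = sum(reinforce_shared(rewards[k], log_likelihoods[k]) for k in range(num_aug)) / num_aug
    return loss_ps + beta * loss_ss + alpha * invariance_term


loss = symnco_loss_sketch([[2.0, 0.0]], [[-1.0, -3.0]], invariance_term=0.0)  # -1.0
```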
- class rl4co.models.zoo.symnco.policy.SymNCOPolicy(env_name, embedding_dim=128, num_encoder_layers=3, num_heads=8, normalization='batch', projection_head=None, use_projection_head=True, **kwargs)
Bases: AutoregressivePolicy
SymNCO Policy based on AutoregressivePolicy. This differs from the default AutoregressivePolicy in that it projects the initial embeddings to a lower dimension using a projection head and returns them. This is used in the SymNCO algorithm to compute the invariance loss. Based on Kim et al. (2022): https://arxiv.org/abs/2205.13209.
- Parameters:
env_name (str) – Name of the environment
embedding_dim (int) – Dimension of the embedding
num_encoder_layers (int) – Number of layers in the encoder
num_heads (int) – Number of heads in the encoder
normalization (str) – Normalization to use in the encoder
projection_head (Module) – Projection head to use
use_projection_head (bool) – Whether to use projection head
**kwargs – Keyword arguments passed to the superclass
- forward(td, env=None, phase='train', return_actions=False, return_entropy=False, return_init_embeds=True, **decoder_kwargs)
Forward pass of the policy.
- Parameters:
td (TensorDict) – TensorDict containing the environment state
env (Union[str, RL4COEnvBase]) – Environment to use for decoding
phase (str) – Phase of the algorithm (train, val, test)
return_actions (bool) – Whether to return the actions
return_entropy (bool) – Whether to return the entropy
decoder_kwargs – Keyword arguments for the decoder. See rl4co.models.zoo.common.autoregressive.decoder.AutoregressiveDecoder
- Returns:
out: dictionary containing the reward, log likelihood, and optionally the actions and entropy
- rl4co.models.zoo.symnco.losses.invariance_loss(proj_embed, num_augment)
Loss for invariant representation on projected nodes. Corresponds to L_inv in the SymNCO paper.
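L_inv rewards projected embeddings that stay the same under the symmetric transformations of an instance. A pure-Python sketch of the idea, the negative mean cosine similarity between the original instance's projected embedding and those of its augmented copies, follows. The function names are hypothetical, and unlike invariance_loss the sketch takes the augmentations as a list rather than a flattened batch split by num_augment.

```python
import math


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity of two vectors, with eps guarding against zero norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def invariance_loss_sketch(proj_embeds):
    """proj_embeds[k]: projected embedding of augmentation k (k = 0 is the original)."""
    original = proj_embeds[0]
    sims = [cosine_similarity(original, other) for other in proj_embeds[1:]]
    return -sum(sims) / len(sims)


loss = invariance_loss_sketch([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # -0.5
```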
Search Methods
Active Search (AS)
- class rl4co.models.zoo.active_search.search.ActiveSearch(env, policy, dataset, batch_size=1, max_iters=200, augment_size=8, augment_dihedral=True, num_parallel_runs=1, max_runtime=86400, save_path=None, optimizer='Adam', optimizer_kwargs={'lr': 0.00026, 'weight_decay': 1e-06}, **kwargs)
Bases: SearchBase
Active Search for Neural Combinatorial Optimization from Bello et al. (2016). Fine-tunes the whole policy network (encoder + decoder) on a batch of instances. Reference: https://arxiv.org/abs/1611.09940
- Parameters:
env – RL4CO environment to be solved
policy – policy network
dataset (Union[Dataset, str]) – dataset to be used for training
batch_size (int) – batch size for training
max_iters (int) – maximum number of iterations
augment_size (int) – number of augmentations per state
augment_dihedral (bool) – whether to augment with dihedral rotations
num_parallel_runs – number of parallel runs
max_runtime (int) – maximum runtime in seconds
save_path (str) – path to save solution checkpoints
optimizer (Union[str, Optimizer, partial]) – optimizer to use for training
optimizer_kwargs (dict) – keyword arguments for optimizer
**kwargs – additional keyword arguments
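The dihedral augmentation can be sketched for 2D coordinates in the unit square: the eight symmetries of the square (swaps and reflections of x and y) map an instance to eight copies with identical pairwise distances, so every tour has the same length in each copy and the policy can be run on all of them, keeping the best result. The function names are hypothetical and the sketch assumes coordinates in [0, 1].

```python
import math


def dihedral_8(points):
    """Return the 8 symmetric copies of 2D points in the unit square."""
    transforms = [
        lambda x, y: (x, y),
        lambda x, y: (y, x),
        lambda x, y: (x, 1 - y),
        lambda x, y: (y, 1 - x),
        lambda x, y: (1 - x, y),
        lambda x, y: (1 - y, x),
        lambda x, y: (1 - x, 1 - y),
        lambda x, y: (1 - y, 1 - x),
    ]
    return [[t(x, y) for x, y in points] for t in transforms]


def tour_length(points):
    """Length of the closed tour visiting the points in the given order."""
    return sum(math.dist(points[i], points[(i + 1) % len(points)]) for i in range(len(points)))


copies = dihedral_8([(0.1, 0.3), (0.6, 0.2), (0.4, 0.9)])
lengths = [tour_length(c) for c in copies]  # all equal up to float rounding
```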
- on_train_batch_end(outputs, batch, batch_idx)
We store the best solution and reward found.
- Return type:
None
- on_train_batch_start(batch, batch_idx)
Called before training (i.e. search) for a new batch begins. We re-load the original policy state dict and configure the optimizer.
- on_train_epoch_end()
Called when the training epoch ends. Once the epoch ends, the search over all instances is finished, so the trainer should stop.
- Return type:
None
Efficient Active Search (EAS)
- class rl4co.models.zoo.eas.search.EAS(env, policy, dataset, use_eas_embedding=True, use_eas_layer=False, eas_emb_cache_keys=['logit_key'], eas_lambda=0.013, batch_size=2, max_iters=200, augment_size=8, augment_dihedral=True, num_parallel_runs=1, baseline='multistart', max_runtime=86400, save_path=None, optimizer='Adam', optimizer_kwargs={'lr': 0.0041, 'weight_decay': 1e-06}, verbose=True, **kwargs)
Bases: SearchBase
Efficient Active Search for Neural Combinatorial Optimization from Hottung et al. (2022). Fine-tunes a subset of parameters (such as node embeddings or newly added layers), thus avoiding expensive re-encoding of the problem. Reference: https://openreview.net/pdf?id=nO5caZwFwYu
- Parameters:
env – RL4CO environment to be solved
policy – policy network
dataset (Union[Dataset, str]) – dataset to be used for training
use_eas_embedding (bool) – whether to use EAS embedding (EASEmb)
use_eas_layer (bool) – whether to use EAS layer (EASLay)
eas_emb_cache_keys (List[str]) – keys to cache in the embedding
eas_lambda (float) – lambda parameter for the imitation learning (IL) loss
batch_size (int) – batch size for training
max_iters (int) – maximum number of iterations
augment_size (int) – number of augmentations per state
augment_dihedral (bool) – whether to augment with dihedral rotations
num_parallel_runs – number of parallel runs
baseline (str) – REINFORCE baseline type (multistart, symmetric, full)
max_runtime (int) – maximum runtime in seconds
save_path (str) – path to save solution checkpoints
optimizer (Union[str, Optimizer, partial]) – optimizer to use for training
optimizer_kwargs (dict) – keyword arguments for optimizer
verbose (bool) – whether to print progress for each iteration
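The role of eas_lambda can be sketched as a weighted sum of a REINFORCE term and an imitation-learning term that increases the likelihood of the best solution found so far. In this pure-Python sketch the REINFORCE baseline is the mean reward of the sampled solutions (our assumption for the 'multistart' setting); the function name is hypothetical.

```python
def eas_loss(rewards, log_likelihoods, best_solution_log_likelihood, eas_lambda=0.013):
    """L = L_RL + lambda * L_IL.

    L_RL: REINFORCE with the mean reward of the sampled solutions as baseline.
    L_IL: negative log-likelihood of the best solution found so far.
    """
    baseline = sum(rewards) / len(rewards)
    loss_rl = -sum((r - baseline) * ll for r, ll in zip(rewards, log_likelihoods)) / len(rewards)
    loss_il = -best_solution_log_likelihood
    return loss_rl + eas_lambda * loss_il


loss = eas_loss([1.0, 3.0], [-1.0, -2.0], best_solution_log_likelihood=0.0)  # 0.5
```

Only the adapted parameters (embeddings or the added layer) would receive these gradients, which is what makes EAS cheaper than full Active Search.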
- on_train_batch_end(outputs, batch, batch_idx)
We store the best solution and reward found.
- Return type:
None
- on_train_batch_start(batch, batch_idx)
Called before training (i.e. search) for a new batch begins. We re-load the original policy state dict and configure all parameters not to require gradients. We do the rest in the training step.
- class rl4co.models.zoo.eas.search.EASEmb(*args, **kwargs)
Bases: EAS
EAS with embedding adaptation
- class rl4co.models.zoo.eas.search.EASLay(*args, **kwargs)
Bases: EAS
EAS with layer adaptation
- rl4co.models.zoo.eas.decoder.forward_eas(self, td, cached_embeds, best_solutions, iter_count=0, env=None, decode_type='sampling_multistart', num_starts=None, softmax_temp=None, **unused_kwargs)
Forward pass of the decoder. Given the environment state and the pre-computed embeddings, compute the logits and sample actions.
- Parameters:
td (TensorDict) – Input TensorDict containing the environment state
embeddings – Precomputed embeddings for the nodes. Can be already precomputed cached in form of q, k, v and
env (Union[str, RL4COEnvBase]) – Environment to use for decoding. If None, the environment is instantiated from env_name. Note that it is more efficient to pass an already instantiated environment each time for fine-grained control
decode_type (str) – Type of decoding to use. Can be one of:
  - "sampling": sample from the logits
  - "greedy": take the argmax of the logits
  - "multistart_sampling": sample as sampling, but with multi-start decoding
  - "multistart_greedy": decode as greedy, but with multi-start decoding
num_starts (int) – Number of multi-starts to use. If None, will be calculated from the action mask
softmax_temp (float) – Temperature for the softmax. If None, default softmax is used from the LogitAttention module
calc_reward – Whether to calculate the reward for the decoded sequence
- rl4co.models.zoo.eas.decoder.forward_logit_attn_eas_lay(self, query, key, value, logit_key, mask, softmax_temp=None)
Adds a layer to the forward pass of logit attention, i.e. single-head attention.
- class rl4co.models.zoo.eas.nn.EASLayerNet(num_instances, emb_dim)
Bases: Module
Instantiates weights and biases for the added layer. The layer is defined as: h = relu(emb * W1 + b1); out = h * W2 + b2. Wrapping them in nn.Parameter makes the parameters trainable (requires_grad=True).
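The added layer h = relu(emb * W1 + b1); out = h * W2 + b2 keeps one set of weights per instance (hence num_instances). The pure-Python sketch below shows that layout for a single embedding vector. The names are hypothetical, and the initialization (uniform W1 with bound 1/sqrt(emb_dim), zero W2 and b2) is an assumption based on our reading of the EAS paper: with W2 and b2 at zero the output starts at zero, so, assuming the output is added residually, the adapted policy initially matches the original one.

```python
import random


def init_eas_layer(num_instances, emb_dim, seed=0):
    """Per-instance (W1, b1, W2, b2); W2 and b2 start at zero (assumption)."""
    rng = random.Random(seed)
    bound = 1.0 / emb_dim ** 0.5
    layers = []
    for _ in range(num_instances):
        w1 = [[rng.uniform(-bound, bound) for _ in range(emb_dim)] for _ in range(emb_dim)]
        b1 = [0.0] * emb_dim
        w2 = [[0.0] * emb_dim for _ in range(emb_dim)]
        b2 = [0.0] * emb_dim
        layers.append((w1, b1, w2, b2))
    return layers


def eas_layer_forward(emb, layer):
    """h = relu(emb @ W1 + b1); out = h @ W2 + b2 for one embedding vector."""
    w1, b1, w2, b2 = layer
    dim = len(emb)
    h = [max(0.0, sum(emb[i] * w1[i][j] for i in range(dim)) + b1[j]) for j in range(dim)]
    return [sum(h[i] * w2[i][j] for i in range(dim)) + b2[j] for j in range(dim)]


layers = init_eas_layer(num_instances=2, emb_dim=3)
out = eas_layer_forward([0.5, -1.0, 2.0], layers[0])  # all zeros at initialization
```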