Source code for rl4co.models.zoo.am.policy
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
[docs]class AttentionModelPolicy(AutoregressivePolicy):
"""Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475.
We re-declare the most important arguments here for convenience as in the paper.
See `AutoregressivePolicy` superclass for more details.
Args:
env_name: Name of the environment used to initialize embeddings
embedding_dim: Dimension of the node embeddings
num_encoder_layers: Number of layers in the encoder
num_heads: Number of heads in the attention layers
normalization: Normalization type in the attention layers
**kwargs: keyword arguments passed to the `AutoregressivePolicy`
"""
def __init__(
self,
env_name: str,
embedding_dim: int = 128,
num_encoder_layers: int = 3,
num_heads: int = 8,
normalization: str = "batch",
**kwargs,
):
super(AttentionModelPolicy, self).__init__(
env_name=env_name,
embedding_dim=embedding_dim,
num_encoder_layers=num_encoder_layers,
num_heads=num_heads,
normalization=normalization,
**kwargs,
)