Shortcuts

Source code for rl4co.models.zoo.am.policy

from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy


class AttentionModelPolicy(AutoregressivePolicy):
    """Attention Model Policy based on Kool et al. (2019): https://arxiv.org/abs/1803.08475.

    The most important arguments from the paper are re-declared here for
    convenience. This class adds no logic of its own. Every argument, plus
    any extra keyword arguments, is forwarded unchanged to
    `AutoregressivePolicy`, which does the actual construction. See that
    superclass for more details.

    Args:
        env_name: Name of the environment used to initialize embeddings.
        embedding_dim: Dimension of the node embeddings.
        num_encoder_layers: Number of layers in the encoder.
        num_heads: Number of heads in the attention layers.
        normalization: Normalization type in the attention layers.
        **kwargs: Keyword arguments passed to the `AutoregressivePolicy`.
    """

    def __init__(
        self,
        env_name: str,
        embedding_dim: int = 128,
        num_encoder_layers: int = 3,
        num_heads: int = 8,
        normalization: str = "batch",
        **kwargs,
    ) -> None:
        # Zero-argument super() (PEP 3135) resolves the MRO exactly like
        # super(AttentionModelPolicy, self), without hard-coding the class name.
        super().__init__(
            env_name=env_name,
            embedding_dim=embedding_dim,
            num_encoder_layers=num_encoder_layers,
            num_heads=num_heads,
            normalization=normalization,
            **kwargs,
        )

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.