Shortcuts

Source code for rl4co.models.zoo.common.autoregressive.encoder

from typing import Tuple, Union

import torch.nn as nn

from tensordict import TensorDict
from torch import Tensor

from rl4co.envs import RL4COEnvBase
from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.nn.graph.attnnet import GraphAttentionNetwork


[docs]class GraphAttentionEncoder(nn.Module): """Graph Attention Encoder as in Kool et al. (2019). Args: env_name: environment name to solve num_heads: Number of heads for the attention embedding_dim: Dimension of the embeddings num_layers: Number of layers for the encoder normalization: Normalization to use for the attention feed_forward_hidden: Hidden dimension for the feed-forward network init_embedding: Model to use for the initial embedding. If None, use the default embedding for the environment sdpa_fn: Scaled dot product function to use for the attention """ def __init__( self, env_name: [str, RL4COEnvBase], num_heads: int, embedding_dim: int, num_layers: int, normalization: str = "batch", feed_forward_hidden: int = 512, init_embedding: nn.Module = None, sdpa_fn=None, ): super(GraphAttentionEncoder, self).__init__() if isinstance(env_name, RL4COEnvBase): env_name = env_name.name self.env_name = env_name self.init_embedding = ( env_init_embedding(self.env_name, {"embedding_dim": embedding_dim}) if init_embedding is None else init_embedding ) self.net = GraphAttentionNetwork( num_heads, embedding_dim, num_layers, normalization, feed_forward_hidden, sdpa_fn=sdpa_fn, )
[docs] def forward( self, td: TensorDict, mask: Union[Tensor, None] = None ) -> Tuple[Tensor, Tensor]: """Forward pass of the encoder. Transform the input TensorDict into a latent representation. Args: td: Input TensorDict containing the environment state mask: Mask to apply to the attention Returns: h: Latent representation of the input init_h: Initial embedding of the input """ # Transfer to embedding space init_h = self.init_embedding(td) # Process embedding h = self.net(init_h, mask) # Return latent representation and initial embedding return h, init_h

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.