Shortcuts

Source code for rl4co.models.nn.graph.attnnet

from typing import Callable, Optional

import torch.nn as nn

from torch import Tensor

from rl4co.models.nn.attention import MultiHeadAttention
from rl4co.models.nn.ops import Normalization, SkipConnection
from rl4co.utils.pylogger import get_pylogger

# Module-level logger obtained from the project's `get_pylogger` helper, keyed by
# this module's name. It is not referenced by the definitions visible below.
log = get_pylogger(__name__)


class MultiHeadAttentionLayer(nn.Sequential):
    """Single encoder block: attention and feed-forward sub-layers, each followed by normalization.

    The block is an ``nn.Sequential`` of four modules, registered in this order:
    ``SkipConnection`` around ``MultiHeadAttention``, ``Normalization``,
    ``SkipConnection`` around the feed-forward sub-layer, ``Normalization``.

    Args:
        num_heads: number of heads in the MHA
        embed_dim: dimension of the embeddings
        feed_forward_hidden: dimension of the hidden layer in the feed-forward
            layer; when not positive, a single ``embed_dim -> embed_dim`` linear
            layer is used in place of the two-layer MLP
        normalization: type of normalization to use (batch, layer, none)
        sdpa_fn: scaled dot product attention function (SDPA)
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        feed_forward_hidden: int = 512,
        normalization: Optional[str] = "batch",
        sdpa_fn: Optional[Callable] = None,
    ):
        # Sub-modules are created in the same order in which they are registered,
        # so parameter initialization consumes the RNG in a fixed sequence.
        attention = SkipConnection(
            MultiHeadAttention(embed_dim, num_heads, sdpa_fn=sdpa_fn)
        )
        attention_norm = Normalization(embed_dim, normalization)

        if feed_forward_hidden > 0:
            projection = nn.Sequential(
                nn.Linear(embed_dim, feed_forward_hidden),
                nn.ReLU(),
                nn.Linear(feed_forward_hidden, embed_dim),
            )
        else:
            projection = nn.Linear(embed_dim, embed_dim)
        feed_forward = SkipConnection(projection)
        feed_forward_norm = Normalization(embed_dim, normalization)

        super().__init__(attention, attention_norm, feed_forward, feed_forward_norm)
class GraphAttentionNetwork(nn.Module):
    """Graph Attention Network to encode embeddings with a series of MHA layers
    consisting of a MHA layer, normalization, feed-forward layer, and normalization.
    Similar to Transformer encoder, as used in Kool et al. (2019).

    Args:
        num_heads: number of heads in the MHA
        embedding_dim: dimension of the embeddings
        num_layers: number of MHA layers
        normalization: type of normalization to use (batch, layer, none)
        feed_forward_hidden: dimension of the hidden layer in the feed-forward layer
        sdpa_fn: scaled dot product attention function (SDPA)
    """

    def __init__(
        self,
        num_heads: int,
        embedding_dim: int,
        num_layers: int,
        normalization: str = "batch",
        feed_forward_hidden: int = 512,
        sdpa_fn: Optional[Callable] = None,
    ):
        super().__init__()
        # Stack `num_layers` identically configured attention layers; each layer
        # is a separate module instance with its own parameters.
        self.layers = nn.Sequential(
            *(
                MultiHeadAttentionLayer(
                    num_heads,
                    embedding_dim,
                    feed_forward_hidden=feed_forward_hidden,
                    normalization=normalization,
                    sdpa_fn=sdpa_fn,
                )
                for _ in range(num_layers)
            )
        )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Forward pass of the encoder

        Args:
            x: [batch_size, graph_size, embed_dim] initial embeddings to process
            mask: [batch_size, graph_size, graph_size] mask for the input embeddings.
                Unused for now; must be left as ``None``.

        Returns:
            The result of passing ``x`` through the stacked layers.

        Raises:
            AssertionError: if ``mask`` is not ``None`` (masking is not supported).
        """
        if mask is not None:
            # Raised explicitly instead of via `assert` so the check still runs
            # under `python -O`, where a mask would otherwise be silently
            # ignored. The exception type and message are kept for existing callers.
            raise AssertionError("Mask not yet supported!")
        return self.layers(x)

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.