Shortcuts

Source code for rl4co.models.nn.graph.gcn

from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from tensordict import TensorDict
from torch import Tensor
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv

from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.utils.pylogger import get_pylogger

# Module-level logger, created through rl4co's `get_pylogger` helper and
# named after this module.
log = get_pylogger(__name__)


class GCNEncoder(nn.Module):
    """Graph Convolutional Network to encode embeddings with a series of GCN layers.

    The encoder embeds the input with ``init_embedding``, treats every instance
    in the batch as a fully connected graph over its nodes, and runs the node
    embeddings through a stack of ``GCNConv`` layers.

    Args:
        env_name: name of the environment, used to build the default initial
            embedding when ``init_embedding`` is not given
        embedding_dim: dimension of the embeddings
        num_nodes: number of nodes in the graph the cached edge index is built for
        num_layers: number of GCN layers (must be at least 1)
        init_embedding: module mapping the input TensorDict to node embeddings;
            if None, it is created with ``env_init_embedding``
        self_loop: whether to add self loop in the graph
        residual: whether to use residual connection

    Raises:
        ValueError: if ``num_layers`` is smaller than 1
    """

    def __init__(
        self,
        env_name: str,
        embedding_dim: int,
        num_nodes: int,
        num_layers: int,
        init_embedding: nn.Module = None,
        self_loop: bool = False,
        residual: bool = True,
    ):
        super().__init__()

        # Fail early: forward() always applies the last layer separately, so an
        # empty layer stack would otherwise only surface as an IndexError there.
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.env_name = env_name

        self.init_embedding = (
            env_init_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if init_embedding is None
            else init_embedding
        )

        # Record parameters
        self.num_nodes = num_nodes
        self.residual = residual
        self.self_loop = self_loop

        # Cache the edge index of a fully connected graph with `num_nodes` nodes
        self.edge_index = self._fully_connected_edge_index(num_nodes, self_loop)

        # Define the GCN layers
        self.gcn_layers = nn.ModuleList(
            [GCNConv(embedding_dim, embedding_dim) for _ in range(num_layers)]
        )

    @staticmethod
    def _fully_connected_edge_index(num_nodes: int, self_loop: bool) -> Tensor:
        """Build the ``[2, num_edges]`` edge index of a fully connected graph.

        Args:
            num_nodes: number of nodes in the graph
            self_loop: if False, the ``(i, i)`` edges are left out

        Returns:
            Integer tensor whose columns are the (row, column) pairs of the edges
        """
        adj_matrix = torch.ones(num_nodes, num_nodes)
        if not self_loop:
            adj_matrix.fill_diagonal_(0)  # No self-loops
        # NOTE(review): GCNConv appears to add missing self-loops itself by
        # default (`add_self_loops=True`), so this flag may only affect the
        # stored edge index, not the layer output -- confirm against the
        # installed torch_geometric version.
        return torch.permute(torch.nonzero(adj_matrix), (1, 0))

    def forward(
        self, td: TensorDict, mask: Union[Tensor, None] = None
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass of the encoder.
        Transform the input TensorDict into a latent representation.

        Args:
            td: Input TensorDict containing the environment state
            mask: Accepted for interface compatibility; not used by this encoder

        Returns:
            h: Latent representation of the input
            init_h: Initial embedding of the input
        """
        # Transfer to embedding space
        # assumes init_h is (batch, num_node, embedding_dim) -- TODO confirm
        init_h = self.init_embedding(td)
        num_node = init_h.size(-2)

        # Rebuild the edge index when the input has a different number of nodes
        # than the cached graph. Comparing against the stored node count (rather
        # than `edge_index.max()`) also works when the edge index is empty and
        # avoids a device-to-host sync on every call.
        if num_node != self.num_nodes:
            edge_index = self._fully_connected_edge_index(num_node, self.self_loop)
        else:
            edge_index = self.edge_index
        edge_index = edge_index.to(init_h.device)

        # Create the batched graph: one Data object per instance in the batch
        data_list = [Data(x=x, edge_index=edge_index) for x in init_h]
        data_batch = Batch.from_data_list(data_list)

        # GCN process: ReLU + dropout after every layer except the last one
        update_node_feature = data_batch.x
        edge_index = data_batch.edge_index
        for layer in self.gcn_layers[:-1]:
            update_node_feature = layer(update_node_feature, edge_index)
            update_node_feature = F.relu(update_node_feature)
            update_node_feature = F.dropout(
                update_node_feature, training=self.training
            )
        update_node_feature = self.gcn_layers[-1](update_node_feature, edge_index)

        # De-batch the graph back to the shape of the initial embedding
        update_node_feature = update_node_feature.view(*init_h.size())

        # Residual connection, only when requested
        if self.residual:
            update_node_feature = update_node_feature + init_h

        return update_node_feature, init_h

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.