Shortcuts

Source code for rl4co.models.nn.graph.mpnn

from typing import Tuple, Union

import torch
import torch.nn as nn

from tensordict import TensorDict
from torch import Tensor
from torch_geometric.data import Batch, Data
from torch_geometric.nn import MessagePassing

from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.nn.mlp import MLP
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class MessagePassingLayer(MessagePassing):
    """One round of edge-then-node message passing.

    Every edge feature is first updated from the features of its two endpoint
    nodes (``edge_model``). The updated edge features are then used as the
    messages: PyG's ``propagate`` aggregates them per node with ``aggregation``,
    and ``node_model`` combines each node's feature with its aggregated messages.

    Args:
        node_indim: Dimension of the input node features.
        node_outdim: Dimension of the output node features. With
            ``residual=True`` the input features are added to the output, so
            the two dimensions are expected to match.
        edge_indim: Dimension of the input edge features.
        edge_outdim: Dimension of the output edge features (also the message
            dimension).
        aggregation: Aggregation scheme passed to PyG ``MessagePassing`` as
            ``aggr`` (e.g. ``"add"``).
        residual: If True, add the input node features to the updated ones.
        **mlp_params: Extra keyword arguments forwarded to both ``MLP``s.
    """

    def __init__(
        self,
        node_indim,
        node_outdim,
        edge_indim,
        edge_outdim,
        aggregation="add",
        residual=False,
        **mlp_params,
    ):
        super().__init__(aggr=aggregation)
        # Edge MLP consumes [x_i, x_j, e_ij]: two node vectors and one edge vector.
        self.edge_model = MLP(
            input_dim=edge_indim + 2 * node_indim, output_dim=edge_outdim, **mlp_params
        )
        # Node MLP consumes [x_i, aggregated messages]; messages are the updated
        # edge features, hence ``edge_outdim``.
        self.node_model = MLP(
            input_dim=edge_outdim + node_indim, output_dim=node_outdim, **mlp_params
        )
        self.residual = residual

    def forward(self, node_feature, edge_feature, edge_index, mask=None):
        """Run one message-passing step.

        Args:
            node_feature: Node features; first dimension indexes nodes.
            edge_feature: Edge features, one row per column of ``edge_index``.
            edge_index: Index tensor whose first dimension has size 2
                (row / col endpoint indices for each edge).
            mask: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(updated_node_feature, updated_edge_feature)``.
        """
        # Edges are updated first so that the messages carry the new edge state.
        update_edge_feature = self.edge_update(node_feature, edge_feature, edge_index)
        update_node_feature = self.propagate(
            edge_index, x=node_feature, edge_features=update_edge_feature
        )
        if self.residual:
            update_node_feature = update_node_feature + node_feature
        return update_node_feature, update_edge_feature

    def edge_update(self, nf, ef, edge_index):
        """Compute new edge features from endpoint node features.

        NOTE: this shadows PyG's ``MessagePassing.edge_update`` hook name but is
        called directly from ``forward`` with its own signature.

        Args:
            nf: Node features; first dimension indexes nodes.
            ef: Current edge features, aligned with the columns of ``edge_index``.
            edge_index: Tensor unpacked as ``(row, col)`` endpoint indices.

        Returns:
            ``edge_model(cat([nf[row], nf[col], ef], dim=-1))``.
        """
        row, col = edge_index
        x_i, x_j = nf[row], nf[col]
        return self.edge_model(torch.cat([x_i, x_j, ef], dim=-1))

    def message(self, edge_features: torch.Tensor) -> torch.Tensor:
        """Use the (already updated) edge features directly as the messages."""
        return edge_features

    def update(self, aggr_msg: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Combine each node's feature with its aggregated messages.

        ``aggr_msg`` is the aggregation result supplied by ``propagate``; ``x``
        is the node feature tensor passed to ``propagate`` as ``x=``.
        """
        return self.node_model(torch.cat([x, aggr_msg], dim=-1))
class MessagePassingEncoder(nn.Module):
    """Encode an instance with a stack of ``MessagePassingLayer``s.

    Initial node features come from ``init_embedding``. The nodes are connected
    as a fully connected graph. Each edge's initial feature is the Euclidean
    distance between the initial embeddings of its two endpoints.

    Note:
        - Only fully connected graphs are supported for now.

    Args:
        env_name: Environment name passed to ``env_init_embedding`` when no
            ``init_embedding`` is given.
        embedding_dim: Node embedding dimension used by every layer.
        num_nodes: Node count used to pre-build the edge index. A different
            node count at forward time builds a fresh edge index for that call.
        num_layers: Number of message-passing layers.
        init_embedding: Optional module mapping ``td`` to initial node embeddings.
        aggregation: Aggregation scheme forwarded to every layer.
        self_loop: If True, each node also has an edge to itself; if False, the
            diagonal ``(i, i)`` edges are removed. (Earlier revisions had this
            flag inverted.)
        residual: Whether each layer adds a residual connection on node features.
    """

    def __init__(
        self,
        env_name: str,
        embedding_dim: int,
        num_nodes: int,
        num_layers: int,
        init_embedding: nn.Module = None,
        aggregation: str = "add",
        self_loop: bool = False,
        residual: bool = True,
    ):
        super().__init__()
        self.env_name = env_name
        self.init_embedding = (
            env_init_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if init_embedding is None
            else init_embedding
        )

        # Record graph parameters; forward needs them to rebuild the edge index.
        self.self_loop = self_loop
        self.num_nodes = num_nodes

        # Non-persistent buffer: moves with module.to(device) but stays out of
        # state_dict, so existing checkpoints load unchanged.
        self.register_buffer(
            "edge_index",
            self._fully_connected_edge_index(num_nodes, self_loop),
            persistent=False,
        )

        self.mpnn_layers = nn.ModuleList(
            [
                MessagePassingLayer(
                    node_indim=embedding_dim,
                    node_outdim=embedding_dim,
                    edge_indim=1,
                    edge_outdim=1,
                    aggregation=aggregation,
                    residual=residual,
                )
                for _ in range(num_layers)
            ]
        )

    @staticmethod
    def _fully_connected_edge_index(num_nodes: int, self_loop: bool) -> Tensor:
        """Return the ``(2, E)`` int64 edge index of a fully connected graph.

        Edges are listed in row-major order of the adjacency matrix. Diagonal
        edges are dropped unless ``self_loop`` is True.
        """
        adj_matrix = torch.ones(num_nodes, num_nodes)
        if not self_loop:
            adj_matrix.fill_diagonal_(0)  # no self-loops requested
        return torch.permute(torch.nonzero(adj_matrix), (1, 0))

    def forward(
        self, td: TensorDict, mask: Union[Tensor, None] = None
    ) -> Tuple[Tensor, Tensor]:
        """Embed ``td`` and refine the node embeddings by message passing.

        Args:
            td: Input TensorDict handed to ``init_embedding``.
            mask: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(node_embeddings, init_embeddings)``. Both have the shape
            of ``init_embedding(td)``, whose last two dims are (nodes, features).
        """
        init_h = self.init_embedding(td)
        input_size = init_h.size()
        num_node = input_size[-2]
        device = init_h.device

        # Rebuild the graph only when the instance size differs from the cached one.
        if num_node != self.num_nodes:
            edge_index = self._fully_connected_edge_index(num_node, self.self_loop)
        else:
            edge_index = self.edge_index
        edge_index = edge_index.to(device)

        # Initial edge features: distance between endpoint embeddings, (..., E, 1).
        edge_feature = torch.norm(
            init_h[..., edge_index[0], :] - init_h[..., edge_index[1], :],
            dim=-1,
            keepdim=True,
        )

        # Flatten all leading (batch) dims into one disjoint graph. Graph b's
        # nodes occupy rows [b*N, (b+1)*N), so its edge indices are shifted by
        # b*N. This matches what Batch.from_data_list produces, without a
        # per-sample Python loop.
        batch_size = input_size[:-2].numel()
        update_node_feature = init_h.reshape(-1, input_size[-1])
        update_edge_feature = edge_feature.reshape(-1, 1)
        offsets = torch.arange(batch_size, device=device) * num_node
        batched_edge_index = (
            edge_index.unsqueeze(1) + offsets.view(1, -1, 1)
        ).reshape(2, -1)

        for layer in self.mpnn_layers:
            update_node_feature, update_edge_feature = layer(
                update_node_feature, update_edge_feature, batched_edge_index
            )

        # De-batch back to the original (..., N, D) layout.
        update_node_feature = update_node_feature.reshape(input_size)
        return update_node_feature, init_h

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.