Shortcuts

Source code for rl4co.models.zoo.ptrnet.encoder

import math

import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Maps a graph represented as an input sequence to a hidden vector.

    Wraps a single-layer ``nn.LSTM`` and owns a trainable initial hidden
    state ``(init_hx, init_cx)`` (each a vector of size ``hidden_dim``) that
    callers can expand to start encoding.

    Args:
        input_dim: Size of each input feature vector fed to the LSTM.
        hidden_dim: Size of the LSTM hidden and cell states.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim)
        # Assigning nn.Parameter objects as attributes registers them with the
        # module, so they are optimized and saved in the state dict.
        self.init_hx, self.init_cx = self.init_hidden(hidden_dim)

    def forward(self, x, hidden=None):
        """Run the LSTM over ``x``.

        Args:
            x: Input sequence passed straight to ``nn.LSTM``. The LSTM is built
                with the default ``batch_first=False``, so batched input is
                ``(seq_len, batch, input_dim)``.
            hidden: Optional ``(h_0, c_0)`` tuple forwarded to ``nn.LSTM``.
                ``None`` (the default) lets ``nn.LSTM`` start from zeros.

        Returns:
            The ``(output, (h_n, c_n))`` tuple produced by ``nn.LSTM``.
        """
        output, hidden = self.lstm(x, hidden)
        return output, hidden

    def init_hidden(self, hidden_dim):
        """Create a trainable initial hidden and cell state.

        Both vectors are sampled from U(-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)).

        Args:
            hidden_dim: Length of each returned vector.

        Returns:
            Tuple ``(enc_init_hx, enc_init_cx)`` of 1-D ``nn.Parameter``s.
        """
        std = 1.0 / math.sqrt(hidden_dim)
        # torch.empty follows the default dtype (like nn.LSTM's own weights)
        # instead of the legacy float32-only torch.FloatTensor constructor.
        enc_init_hx = nn.Parameter(torch.empty(hidden_dim))
        enc_init_cx = nn.Parameter(torch.empty(hidden_dim))
        # nn.init.uniform_ fills in place under no_grad, replacing `.data`
        # mutation; hx is drawn before cx to keep the original RNG order.
        nn.init.uniform_(enc_init_hx, -std, std)
        nn.init.uniform_(enc_init_cx, -std, std)
        return enc_init_hx, enc_init_cx

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.