Shortcuts

Source code for rl4co.models.zoo.ptrnet.decoder

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from rl4co.models.nn.utils import decode_probs


class SimpleAttention(nn.Module):
    """A generic additive attention module for a decoder in seq2seq.

    Every encoder position is scored against the current decoder state as
    ``v^T tanh(W_q * query + W_ref * ref)``. The scores can optionally be
    squashed into ``[-C, C]`` with ``C * tanh(.)``.

    Args:
        dim: feature size of the query and of the reference vectors
        use_tanh: if True, return ``C * tanh(u)`` instead of the raw scores ``u``
        C: scale of the tanh squashing (tanh exploration)
    """

    def __init__(self, dim, use_tanh=False, C=10):
        super().__init__()
        self.use_tanh = use_tanh
        self.project_query = nn.Linear(dim, dim)
        # A kernel-size-1 convolution is a per-position linear map over channels
        self.project_ref = nn.Conv1d(dim, dim, 1, 1)
        self.C = C  # tanh exploration

        # `torch.empty` follows the default dtype (the legacy `FloatTensor`
        # constructor did not); the values are drawn right below.
        bound = 1.0 / math.sqrt(dim)
        self.v = nn.Parameter(torch.empty(dim))
        nn.init.uniform_(self.v, -bound, bound)

    def forward(self, query, ref):
        """Score every reference position against the query.

        Args:
            query: is the hidden state of the decoder at the current time step.
                [batch x dim]
            ref: the set of hidden states from the encoder.
                [sourceL x batch x hidden_dim]

        Returns:
            Tuple ``(e, logits)`` where ``e`` is the projected reference
            [batch x hidden_dim x sourceL] and ``logits`` are the attention
            scores [batch x sourceL].
        """
        # Conv1d wants channels first: [batch x hidden_dim x sourceL]
        ref = ref.permute(1, 2, 0)
        q = self.project_query(query).unsqueeze(2)  # [batch x dim x 1]
        e = self.project_ref(ref)  # [batch x hidden_dim x sourceL]

        # `q` broadcasts over the sourceL axis, so it does not need to be
        # repeated; likewise `v` does not need a per-batch copy:
        # [dim] @ [batch x dim x sourceL] -> [batch x sourceL]
        u = torch.matmul(self.v, torch.tanh(q + e))

        logits = self.C * torch.tanh(u) if self.use_tanh else u
        return e, logits
class Decoder(nn.Module):
    """Pointer-network decoder built from an LSTM cell and two attention modules.

    At every step the LSTM state is refined by ``num_glimpses`` attention
    read-outs over the encoder outputs and then used by the pointer attention
    to produce logits over the input positions.

    Args:
        embedding_dim: size of the embeddings fed to the LSTM cell
        hidden_dim: size of the LSTM state and of both attention modules
        tanh_exploration: scale ``C`` of the tanh squashing of the pointer logits
        use_tanh: whether the pointer logits are squashed with ``C * tanh(.)``
        num_glimpses: number of glimpse rounds performed before pointing
        mask_glimpses: whether already selected positions are masked in the glimpses
        mask_logits: whether already selected positions are masked in the pointer logits
    """

    def __init__(
        self,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        tanh_exploration: float = 10.0,
        use_tanh: bool = True,
        num_glimpses: int = 1,
        mask_glimpses: bool = True,
        mask_logits: bool = True,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_glimpses = num_glimpses
        self.mask_glimpses = mask_glimpses
        self.mask_logits = mask_logits
        self.use_tanh = use_tanh
        self.tanh_exploration = tanh_exploration

        self.lstm = nn.LSTMCell(embedding_dim, hidden_dim)
        self.pointer = SimpleAttention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration)
        self.glimpse = SimpleAttention(hidden_dim, use_tanh=False)

    def update_mask(self, mask, selected):
        """Return a copy of ``mask`` with entry ``selected[b]`` of each row ``b`` set to True.

        The input ``mask`` is left untouched.
        """
        return mask.clone().scatter_(1, selected.unsqueeze(-1), True)

    def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context):
        """Run a single decoding step.

        Args:
            x: input of the LSTM cell for this step
            h_in: previous LSTM state ``(h, c)``
            prev_mask: boolean mask of the positions selected before the previous step
            prev_idxs: indices selected at the previous step, or None at the first step
            step: index of the current step (not used by the computation)
            context: encoder outputs attended over

        Returns:
            Tuple ``(h_out, log_p, probs, logit_mask)``: the new LSTM state, the
            log-probabilities over positions, the matching probabilities and the
            mask including ``prev_idxs``.
        """
        if prev_idxs is None:
            logit_mask = prev_mask
        else:
            logit_mask = self.update_mask(prev_mask, prev_idxs)

        logits, h_out = self.calc_logits(
            x, h_in, logit_mask, context, self.mask_glimpses, self.mask_logits
        )

        # Calculate log_softmax for better numerical stability
        log_p = torch.log_softmax(logits, dim=1)
        probs = log_p.exp()

        if not self.mask_logits:
            # Out-of-place: `exp` keeps its output for the backward pass, so
            # writing into it would invalidate any gradient through `probs`.
            probs = probs.masked_fill(logit_mask, 0.0)

        return h_out, log_p, probs, logit_mask

    def calc_logits(
        self, x, h_in, logit_mask, context, mask_glimpses=None, mask_logits=None
    ):
        """Advance the LSTM and compute the pointer logits.

        Args:
            x: input of the LSTM cell
            h_in: previous LSTM state ``(h, c)``
            logit_mask: boolean mask, True where a position must not be selected
            context: encoder outputs attended over
            mask_glimpses: overrides ``self.mask_glimpses`` when not None
            mask_logits: overrides ``self.mask_logits`` when not None

        Returns:
            Tuple ``(logits, h_out)`` with the pointer logits and the new LSTM
            state ``(h, c)``.
        """
        if mask_glimpses is None:
            mask_glimpses = self.mask_glimpses
        if mask_logits is None:
            mask_logits = self.mask_logits

        hy, cy = self.lstm(x, h_in)
        g_l, h_out = hy, (hy, cy)

        for _ in range(self.num_glimpses):
            ref, logits = self.glimpse(g_l, context)
            # For the glimpses, only mask before softmax so we have always an
            # L1 norm 1 readout vector
            if mask_glimpses:
                logits = logits.masked_fill(logit_mask, float("-inf"))
            # [batch_size x h_dim x sourceL] * [batch_size x sourceL x 1] =
            # [batch_size x h_dim x 1]
            g_l = torch.bmm(ref, F.softmax(logits, dim=1).unsqueeze(2)).squeeze(2)

        _, logits = self.pointer(g_l, context)

        # Masking before softmax makes probs sum to one
        if mask_logits:
            logits = logits.masked_fill(logit_mask, float("-inf"))

        return logits, h_out

    def forward(
        self,
        decoder_input,
        embedded_inputs,
        hidden,
        context,
        decode_type="sampling",
        eval_tours=None,
    ):
        """Decode one index per input position.

        Args:
            decoder_input: The initial input to the decoder
                size is [batch_size x embedding_dim]. Trainable parameter.
            embedded_inputs: [sourceL x batch_size x embedding_dim]
            hidden: the prev hidden state, size is [batch_size x hidden_dim].
                Initially this is set to (enc_h[-1], enc_c[-1])
            context: encoder outputs, [sourceL x batch_size x hidden_dim]
            decode_type: forwarded to ``decode_probs`` to pick the next index;
                ignored when ``eval_tours`` is given
            eval_tours: optional [batch_size x sourceL] indices to follow
                instead of decoding, e.g. to score given tours

        Returns:
            Tuple ``((log_p, selections), hidden)``: the stacked per-step
            log-probabilities [batch_size x sourceL x sourceL], the selected
            indices [batch_size x sourceL] and the final LSTM state.
        """
        seq_len = embedded_inputs.size(0)
        batch_size = context.size(1)

        outputs = []
        selections = []
        idxs = None
        mask = torch.zeros(
            embedded_inputs.size(1),
            seq_len,
            dtype=torch.bool,
            device=embedded_inputs.device,
        )

        for step in range(seq_len):
            hidden, log_p, probs, mask = self.recurrence(
                decoder_input, hidden, mask, idxs, step, context
            )

            # Either decode the next index or follow the tour to evaluate
            if eval_tours is None:
                idxs = decode_probs(probs, mask, decode_type=decode_type)
            else:
                idxs = eval_tours[:, step]
            # The selection itself is not differentiable; detach it so it is
            # never treated as part of the graph.
            idxs = idxs.detach()

            # Gather the embedding of the selected position as the next
            # decoder input [batch_size x embedding_dim]
            gather_index = (
                idxs.contiguous()
                .view(1, batch_size, 1)
                .expand(1, batch_size, *embedded_inputs.size()[2:])
            )
            decoder_input = torch.gather(embedded_inputs, 0, gather_index).squeeze(0)

            outputs.append(log_p)
            selections.append(idxs)

        return (torch.stack(outputs, 1), torch.stack(selections, 1)), hidden

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.