Shortcuts

Source code for rl4co.models.zoo.common.autoregressive.decoder

from dataclasses import dataclass
from typing import Tuple, Union

import torch
import torch.nn as nn

from einops import rearrange
from tensordict import TensorDict
from torch import Tensor

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding
from rl4co.models.nn.env_embeddings.dynamic import StaticEmbedding
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
from rl4co.utils.ops import batchify, get_num_starts, select_start_nodes, unbatchify
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


@dataclass
class PrecomputedCache:
    """Container for the decoder inputs that are computed once before the decoding loop.

    The projections stored here are derived from the node embeddings a single time and
    then reused at every autoregressive step.
    """

    # Node embeddings handed to the decoder
    node_embeddings: Tensor
    # Projected graph context, or the plain number 0 when the graph context is disabled
    graph_context: Union[Tensor, float]
    # Static key tensor for the glimpse attention
    glimpse_key: Tensor
    # Static value tensor for the glimpse attention
    glimpse_val: Tensor
    # Static key tensor used to compute the final logits
    logit_key: Tensor
class AutoregressiveDecoder(nn.Module):
    """Auto-regressive decoder for constructing solutions for combinatorial optimization problems.
    Given the environment state and the embeddings, compute the logits and sample actions autoregressively until
    all the environments in the batch have reached a terminal state.
    We additionally include support for multi-starts as it is more efficient to do so in the decoder as we can
    natively perform the attention computation.

    Note:
        There are major differences between this decoding and most RL problems. The most important one is
        that reward is not defined for partial solutions, hence we have to wait for the environment to reach a terminal
        state before we can compute the reward with `env.get_reward()`.

    Warning:
        We suppose environments in the `done` state are still available for sampling. This is because in NCO we need to
        wait for all the environments to reach a terminal state before we can stop the decoding process. This is in
        contrast with the TorchRL framework (at the moment) where the `env.rollout` function automatically resets.
        You may follow tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72.

    Args:
        env_name: environment name to solve (an environment instance is also accepted; its `name` is used)
        embedding_dim: Dimension of the embeddings
        num_heads: Number of heads for the attention
        use_graph_context: Whether to use the initial graph context to modify the query
        select_start_nodes_fn: Function to select the start nodes for multi-start decoding
        linear_bias: Whether to use a bias in the linear projection of the embeddings
        context_embedding: Module to compute the context embedding. If None, the default is used
        dynamic_embedding: Module to compute the dynamic embedding. If None, the default is used
        **logit_attn_kwargs: Extra keyword arguments forwarded to `LogitAttention`
    """

    def __init__(
        self,
        env_name: Union[str, RL4COEnvBase],
        embedding_dim: int,
        num_heads: int,
        use_graph_context: bool = True,
        select_start_nodes_fn: callable = select_start_nodes,
        linear_bias: bool = False,
        context_embedding: nn.Module = None,
        dynamic_embedding: nn.Module = None,
        **logit_attn_kwargs,
    ):
        super().__init__()

        # Only the environment name is stored; instances are reduced to their name
        if isinstance(env_name, RL4COEnvBase):
            env_name = env_name.name
        self.env_name = env_name
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        assert (
            embedding_dim % num_heads == 0
        ), f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"

        self.context_embedding = (
            env_context_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if context_embedding is None
            else context_embedding
        )
        self.dynamic_embedding = (
            env_dynamic_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if dynamic_embedding is None
            else dynamic_embedding
        )
        # Anything other than StaticEmbedding forces `_get_log_p` to expand the cached
        # keys/values per start instead of sharing them across multi-starts
        self.is_dynamic_embedding = not isinstance(
            self.dynamic_embedding, StaticEmbedding
        )
        self.use_graph_context = use_graph_context

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 3 * embedding_dim, bias=linear_bias
        )
        self.project_fixed_context = nn.Linear(
            embedding_dim, embedding_dim, bias=linear_bias
        )

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        self.select_start_nodes_fn = select_start_nodes_fn

    def _resolve_env(self, env: Union[str, RL4COEnvBase, None]) -> RL4COEnvBase:
        """Return an environment instance.

        A string is instantiated with `get_env`; None falls back to `self.env_name`;
        anything else is returned unchanged.
        """
        if env is None or isinstance(env, str):
            env_name = self.env_name if env is None else env
            env = get_env(env_name)
        return env

    def forward(
        self,
        td: TensorDict,
        embeddings: Tensor,
        env: Union[str, RL4COEnvBase] = None,
        decode_type: str = "sampling",
        num_starts: int = None,
        softmax_temp: float = None,
        calc_reward: bool = True,
    ) -> Tuple[Tensor, Tensor, TensorDict]:
        """Forward pass of the decoder
        Given the environment state and the pre-computed embeddings, compute the logits and sample actions

        Args:
            td: Input TensorDict containing the environment state
            embeddings: Precomputed embeddings for the nodes
            env: Environment to use for decoding. If None, the environment is instantiated from `env_name`. Note that
                it is more efficient to pass an already instantiated environment each time for fine-grained control
            decode_type: Type of decoding to use. Can be one of:
                - "sampling": sample from the logits
                - "greedy": take the argmax of the logits
                - "multistart_sampling": sample as sampling, but with multi-start decoding
                - "multistart_greedy": sample as greedy, but with multi-start decoding
            num_starts: Number of multi-starts to use. If None, will be calculated from the action mask.
                Ignored (with a warning if > 1) for non multi-start decode types
            softmax_temp: Temperature for the softmax. If None, default softmax is used from the `LogitAttention` module
            calc_reward: Whether to calculate the reward for the decoded sequence

        Returns:
            outputs: Tensor of shape (batch_size, seq_len, num_nodes) containing the log-probabilities of each step
                (with multi-start decoding the batch dimension is batch_size * num_starts)
            actions: Tensor of shape (batch_size, seq_len) containing the sampled actions
            td: TensorDict containing the environment state after decoding
        """
        # Instantiate environment if needed (also when env is None, as documented)
        env = self._resolve_env(env)

        # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
        if "multistart" in decode_type:
            if num_starts is None:
                num_starts = get_num_starts(td, env.name)
        else:
            if num_starts is not None and num_starts > 1:
                log.warning(
                    f"num_starts={num_starts} is ignored for decode_type={decode_type}"
                )
            num_starts = 0

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute_cache(embeddings, td=td)

        # Collect outputs
        outputs = []
        actions = []

        # Multi-start decoding: first action is chosen by ad-hoc node selection
        if num_starts > 1 or "multistart" in decode_type:
            action = self.select_start_nodes_fn(td, env, num_starts=num_starts)

            # Expand td to batch_size * num_starts
            td = batchify(td, num_starts)

            td.set("action", action)
            td = env.step(td)["next"]
            log_p = torch.zeros_like(
                td["action_mask"], device=td.device
            )  # first log_p is 0, so p = log_p.exp() = 1

            outputs.append(log_p)
            actions.append(action)

        # Main decoding: loop until all sequences are done
        while not td["done"].all():
            log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(log_p.exp(), mask, decode_type=decode_type)

            td.set("action", action)
            td = env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p)
            actions.append(action)

        assert (
            len(outputs) > 0
        ), "No outputs were collected because all environments were done. Check your initial state"
        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        if calc_reward:
            td.set("reward", env.get_reward(td, actions))

        return outputs, actions, td

    def _precompute_cache(
        self,
        embeddings: Tensor,
        td: TensorDict = None,
    ):
        """Compute the cached embeddings for the attention

        Args:
            embeddings: Precomputed embeddings for the nodes
            td: TensorDict containing the environment state.
                This one is not used in this class. However, passing Tensordict can be useful in child classes.

        Returns:
            PrecomputedCache holding the node embeddings, graph context and the static
            glimpse key/value and logit key projections
        """
        # The projection of the node embeddings for the attention is calculated once up front
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
            logit_key_fixed,
        ) = self.project_node_embeddings(embeddings).chunk(3, dim=-1)

        # Optionally disable the graph context from the initial embedding as done in POMO
        if self.use_graph_context:
            graph_context = self.project_fixed_context(embeddings.mean(1))
        else:
            graph_context = 0

        # Organize in a dataclass for easy access
        return PrecomputedCache(
            node_embeddings=embeddings,
            graph_context=graph_context,
            glimpse_key=glimpse_key_fixed,
            glimpse_val=glimpse_val_fixed,
            logit_key=logit_key_fixed,
        )

    def _get_log_p(
        self,
        cached: PrecomputedCache,
        td: TensorDict,
        softmax_temp: float = None,
        num_starts: int = 0,
    ):
        """Compute the log probabilities of the next actions given the current state

        Args:
            cached: Precomputed embeddings
            td: TensorDict with the current environment state
            softmax_temp: Temperature for the softmax
            num_starts: Number of starts for the multi-start decoding

        Returns:
            log_p: log-probabilities of the next action, flattened to [B*S, ...] for multi-start
            mask: negation of `td["action_mask"]`, in the same layout as `log_p`
        """
        # Get precomputed (cached) embeddings
        node_embeds_cache, graph_context_cache = (
            cached.node_embeddings,
            cached.graph_context,
        )
        glimpse_k_stat, glimpse_v_stat, logit_k_stat = (
            cached.glimpse_key,
            cached.glimpse_val,
            cached.logit_key,
        )  # [B, N, H]
        has_dyn_emb_multi_start = self.is_dynamic_embedding and num_starts > 1

        # Handle efficient multi-start decoding
        if has_dyn_emb_multi_start:
            # if num_starts > 1 and we have some dynamic embeddings, we need to reshape them to [B*S, ...]
            # since keys and values are not shared across starts (i.e. the episodes modify these embeddings at each step)
            glimpse_k_stat = batchify(glimpse_k_stat, num_starts)
            glimpse_v_stat = batchify(glimpse_v_stat, num_starts)
            logit_k_stat = batchify(logit_k_stat, num_starts)
            node_embeds_cache = batchify(node_embeds_cache, num_starts)
            graph_context_cache = (
                batchify(graph_context_cache, num_starts)
                if isinstance(graph_context_cache, Tensor)
                else graph_context_cache
            )
        elif num_starts > 1:
            # Static embeddings are shared across starts: give the starts their own dimension in td
            # instead of expanding the cache (undone by the rearrange at the end)
            td = unbatchify(td, num_starts)
            if isinstance(graph_context_cache, Tensor):
                # add a dimension for num_starts (will automatically be broadcasted during addition)
                graph_context_cache = graph_context_cache.unsqueeze(1)

        step_context = self.context_embedding(node_embeds_cache, td)
        glimpse_q = step_context + graph_context_cache
        # add seq_len dim if not present
        glimpse_q = glimpse_q.unsqueeze(1) if glimpse_q.ndim == 2 else glimpse_q

        # Compute dynamic embeddings and add to static embeddings
        glimpse_k_dyn, glimpse_v_dyn, logit_k_dyn = self.dynamic_embedding(td)
        glimpse_k = glimpse_k_stat + glimpse_k_dyn
        glimpse_v = glimpse_v_stat + glimpse_v_dyn
        logit_k = logit_k_stat + logit_k_dyn

        # Get the mask (negated action mask; presumably True marks infeasible actions)
        mask = ~td["action_mask"]

        # Compute logits
        log_p = self.logit_attention(
            glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp
        )

        # Now we need to reshape the logits and log_p to [B*S,N,...] if num_starts > 1 without dynamic embeddings
        # note that rearranging order is important here
        if num_starts > 1 and not has_dyn_emb_multi_start:
            log_p = rearrange(log_p, "b s l -> (s b) l", s=num_starts)
            mask = rearrange(mask, "b s l -> (s b) l", s=num_starts)
        return log_p, mask

    def evaluate_action(
        self,
        td: TensorDict,
        embeddings: Tensor,
        action: Tensor,
        env: Union[str, RL4COEnvBase] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Evaluate the (old) action to compute log likelihood of the actions and corresponding entropy

        Args:
            td: Input TensorDict containing the environment state
            embeddings: Precomputed embeddings for the nodes
            action: Action to evaluate (batch_size, seq_len)
            env: Environment to use for decoding. If None, the environment is instantiated from `env_name`. Note that
                it is more efficient to pass an already instantiated environment each time for fine-grained control

        Returns:
            ll: Tensor of shape (batch_size, decoding_steps) containing the log-likelihood of each evaluated action
            entropy: Tensor of shape (batch_size,) containing the policy entropy summed over decoding steps
        """
        # Instantiate environment if needed
        env = self._resolve_env(env)

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        # (td is forwarded for consistency with `forward`, so child classes can use it)
        cached_embeds = self._precompute_cache(embeddings, td=td)

        log_p = []
        decode_step = 0
        while not td["done"].all():
            log_p_, _ = self._get_log_p(cached_embeds, td)
            action_ = action[..., decode_step]

            td.set("action", action_)
            td = env.step(td)["next"]
            log_p.append(log_p_)

            decode_step += 1

        # Note that the decoding steps may not be equal to the decoding steps of actions
        # due to the padded zeros in the actions

        # Compute log likelihood of the actions
        log_p = torch.stack(log_p, 1)  # [batch_size, decoding steps, num_nodes]
        ll = get_log_likelihood(
            log_p, action[..., :decode_step], mask=None, return_sum=False
        )  # [batch_size, decoding steps]
        assert ll.isfinite().all(), "Log p is not finite"

        # compute entropy
        log_p = torch.nan_to_num(log_p, nan=0.0)
        entropy = -(log_p.exp() * log_p).sum(dim=-1)  # [batch, decoder steps]
        entropy = entropy.sum(dim=1)  # [batch] -- sum over decoding steps
        assert entropy.isfinite().all(), "Entropy is not finite"

        return ll, entropy

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.