from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from tensordict import TensorDict
from torch import Tensor

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding
from rl4co.models.nn.utils import decode_probs
from rl4co.utils.ops import batchify, get_num_starts, select_start_nodes, unbatchify
from rl4co.utils.pylogger import get_pylogger
# Module-level logger; the decoder uses it to warn when `num_starts` is ignored.
log = get_pylogger(__name__)
@dataclass
class PrecomputedCache:
    """Decoder quantities computed once per forward pass and reused at every decoding step.

    Built by `AutoregressiveDecoder._precompute_cache` and read by `_get_log_p`.
    """

    # Encoder output for the nodes, passed unchanged to the context embedding
    node_embeddings: Tensor
    # Projected mean of the node embeddings, or the scalar 0 when the graph context is disabled
    graph_context: Union[Tensor, float]
    # Fixed (state-independent) projections of the node embeddings used by the attention
    glimpse_key: Tensor
    glimpse_val: Tensor
    logit_key: Tensor
class AutoregressiveDecoder(nn.Module):
    """Auto-regressive decoder for constructing solutions for combinatorial optimization problems.

    Given the environment state and the embeddings, compute the logits and sample actions autoregressively until
    all the environments in the batch have reached a terminal state.
    We additionally include support for multi-starts, because it is more efficient to do so in the decoder, where
    the attention computation can be performed natively.

    Note:
        There are major differences between this decoding and most RL problems. The most important one is
        that reward is not defined for partial solutions, hence we have to wait for the environment to reach a terminal
        state before we can compute the reward with `env.get_reward()`.

    Warning:
        We suppose environments in the `done` state are still available for sampling. This is because in NCO we need to
        wait for all the environments to reach a terminal state before we can stop the decoding process. This is in
        contrast with the TorchRL framework (at the moment) where the `env.rollout` function automatically resets.
        You may follow tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72.

    Args:
        env_name: environment name to solve
        embedding_dim: Dimension of the embeddings. Must be divisible by `num_heads`
        num_heads: Number of heads for the attention
        use_graph_context: Whether to use the initial graph context to modify the query
        select_start_nodes_fn: Function to select the start nodes for multi-start decoding
        linear_bias: Whether to use a bias in the linear projection of the embeddings
        context_embedding: Module to compute the context embedding. If None, the default is used
        dynamic_embedding: Module to compute the dynamic embedding. If None, the default is used
        **logit_attn_kwargs: Extra keyword arguments forwarded to `LogitAttention`
    """

    def __init__(
        self,
        env_name: str,
        embedding_dim: int,
        num_heads: int,
        use_graph_context: bool = True,
        select_start_nodes_fn: Callable = select_start_nodes,
        linear_bias: bool = False,
        context_embedding: Optional[nn.Module] = None,
        dynamic_embedding: Optional[nn.Module] = None,
        **logit_attn_kwargs,
    ):
        super().__init__()
        self.env_name = env_name
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        assert embedding_dim % num_heads == 0, (
            f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
        )

        # Fall back to the environment-specific default embeddings when none are supplied
        self.context_embedding = (
            env_context_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if context_embedding is None
            else context_embedding
        )
        self.dynamic_embedding = (
            env_dynamic_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if dynamic_embedding is None
            else dynamic_embedding
        )
        self.use_graph_context = use_graph_context

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=linear_bias)
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias)

        # MHA
        self.logit_attention = LogitAttention(embedding_dim, num_heads, **logit_attn_kwargs)
        self.select_start_nodes_fn = select_start_nodes_fn

    def forward(
        self,
        td: TensorDict,
        embeddings: Tensor,
        env: Optional[Union[str, RL4COEnvBase]] = None,
        decode_type: str = "sampling",
        num_starts: Optional[int] = None,
        softmax_temp: Optional[float] = None,
        calc_reward: bool = True,
    ) -> Tuple[Tensor, Tensor, TensorDict]:
        """Forward pass of the decoder.

        Given the environment state and the pre-computed embeddings, compute the log-probabilities and sample actions
        until every environment in the batch is done.

        Args:
            td: Input TensorDict containing the environment state
            embeddings: Precomputed embeddings for the nodes
            env: Environment to use for decoding. Either an instance, an environment name, or None. If None, the
                environment is instantiated from `self.env_name`. Passing an already instantiated environment each
                time is more efficient and gives finer-grained control
            decode_type: Type of decoding to use. Can be one of:
                - "sampling": sample from the logits
                - "greedy": take the argmax of the logits
                - "multistart_sampling": sample as sampling, but with multi-start decoding
                - "multistart_greedy": sample as greedy, but with multi-start decoding
            num_starts: Number of multi-starts to use. If None, it is calculated from the action mask.
                Ignored (with a warning if > 1) for non-multistart decode types
            softmax_temp: Temperature for the softmax. If None, the default softmax is used from the `LogitAttention`
                module
            calc_reward: Whether to calculate the reward for the decoded sequence and store it under "reward" in `td`

        Returns:
            outputs: log-probabilities stacked along dim 1, i.e. (batch, seq_len, num_nodes); with multi-start
                decoding the batch dimension is batch_size * num_starts
            actions: actions stacked along dim 1, i.e. (batch, seq_len)
            td: TensorDict containing the environment state after decoding
        """
        # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
        if "multistart" in decode_type:
            if num_starts is None:
                num_starts = get_num_starts(td)
        else:
            if num_starts is not None and num_starts > 1:
                log.warning(f"num_starts={num_starts} is ignored for decode_type={decode_type}")
            num_starts = 0

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute_cache(embeddings, td=td, num_starts=num_starts)

        # Instantiate the environment if only a name (or nothing) was given. `None` must be handled here too,
        # otherwise it would reach `env.step` below and fail.
        if env is None or isinstance(env, str):
            env_name = self.env_name if env is None else env
            env = get_env(env_name)

        # Collect outputs
        outputs = []
        actions = []

        # Multi-start decoding: first action is chosen by ad-hoc node selection
        if num_starts > 1 or "multistart" in decode_type:
            action = self.select_start_nodes_fn(td, env, num_nodes=num_starts)
            # Expand td to batch_size * num_starts
            td = batchify(td, num_starts)
            td.set("action", action)
            td = env.step(td)["next"]
            # The first action is forced, so its log_p is 0 (p = log_p.exp() = 1). Use a floating dtype:
            # zeros_like on the boolean action mask would otherwise yield a bool "log-probability" tensor.
            log_p = torch.zeros_like(td["action_mask"], dtype=embeddings.dtype)
            outputs.append(log_p)
            actions.append(action)

        # Main decoding: loop until all sequences are done
        while not td["done"].all():
            log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(log_p.exp(), mask, decode_type=decode_type)

            td.set("action", action)
            td = env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p)
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        if calc_reward:
            td.set("reward", env.get_reward(td, actions))

        return outputs, actions, td

    def _precompute_cache(
        self, embeddings: Tensor, num_starts: int = 0, td: Optional[TensorDict] = None
    ):
        """Compute the cached embeddings for the attention.

        The fixed projections are computed once here and reused by `_get_log_p` at every decoding step.

        Args:
            embeddings: Precomputed embeddings for the nodes
            num_starts: Number of multi-starts to use. If 0, no multi-start decoding is used
            td: TensorDict containing the environment state.
                This one is not used in this class. However, passing Tensordict can be useful in child classes.

        Returns:
            PrecomputedCache with the node embeddings, the graph context (or 0 when disabled) and the fixed
            glimpse key / glimpse value / logit key projections.
        """
        # The projection of the node embeddings for the attention is calculated once up front
        glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed = self.project_node_embeddings(
            embeddings
        ).chunk(3, dim=-1)

        # Optionally disable the graph context from the initial embedding as done in POMO
        if self.use_graph_context:
            # Mean over dim 1 (presumably the node axis — TODO confirm), projected, then expanded to the
            # multi-start layout so it can be added to the per-start step context
            graph_context = unbatchify(
                batchify(self.project_fixed_context(embeddings.mean(1)), num_starts),
                num_starts,
            )
        else:
            graph_context = 0

        # Organize in a dataclass for easy access
        return PrecomputedCache(
            node_embeddings=embeddings,
            graph_context=graph_context,
            glimpse_key=glimpse_key_fixed,
            glimpse_val=glimpse_val_fixed,
            logit_key=logit_key_fixed,
        )

    def _get_log_p(
        self,
        cached: PrecomputedCache,
        td: TensorDict,
        softmax_temp: Optional[float] = None,
        num_starts: int = 0,
    ):
        """Compute the log probabilities of the next actions given the current state.

        Args:
            cached: Precomputed embeddings returned by `_precompute_cache`
            td: TensorDict with the current environment state
            softmax_temp: Temperature for the softmax
            num_starts: Number of starts for the multi-start decoding

        Returns:
            log_p: log-probabilities of the next actions, flattened back to the batchified layout when
                num_starts > 1
            mask: boolean mask of infeasible actions (the negation of `td["action_mask"]`), same layout as `log_p`
        """
        # Unbatchify to [batch_size, num_starts, ...]. Has no effect if num_starts = 0
        td_unbatch = unbatchify(td, num_starts)

        # Query: current step context plus the (optional) fixed graph context
        step_context = self.context_embedding(cached.node_embeddings, td_unbatch)
        glimpse_q = step_context + cached.graph_context
        glimpse_q = glimpse_q.unsqueeze(1) if glimpse_q.ndim == 2 else glimpse_q

        # Compute keys and values for the nodes: fixed projections plus state-dependent terms
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td_unbatch)
        glimpse_k = cached.glimpse_key + glimpse_key_dynamic
        glimpse_v = cached.glimpse_val + glimpse_val_dynamic
        logit_k = cached.logit_key + logit_key_dynamic

        # Get the mask: True marks actions that are NOT allowed
        mask = ~td_unbatch["action_mask"]

        # Compute logits
        log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp)

        # Now we need to reshape the logits and log_p to [batch_size*num_starts, num_nodes]
        # Note that rearranging order is important here: "(s b)" mirrors the start-major order produced by batchify.
        # NOTE(review): the reshape is skipped for num_starts == 1, while unbatchify above is documented as a
        # no-op only for 0 — confirm num_starts == 1 is handled consistently.
        if num_starts > 1:
            log_p = rearrange(log_p, "b s l -> (s b) l")
            mask = rearrange(mask, "b s l -> (s b) l")
        return log_p, mask