Shortcuts

Source code for rl4co.models.zoo.ppo.decoder

from typing import Tuple, Union

import torch
from tensordict import TensorDict
from torch import Tensor

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.utils import get_log_likelihood
from rl4co.models.zoo.common.autoregressive import AutoregressiveDecoder


class PPODecoder(AutoregressiveDecoder):
    """Autoregressive decoder that can additionally re-evaluate given (old) actions."""

    def evaluate_action(
        self,
        td: TensorDict,
        embeddings: Tensor,
        action: Tensor,
        env: Union[str, RL4COEnvBase] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Evaluate the (old) actions: per-step log-likelihood and summed entropy.

        The environment is replayed step by step from ``td``. At each step the
        decoder's log-probabilities are recomputed and the given ``action`` is
        fed to ``env.step``. This continues until every instance in the batch
        reports ``done``.

        Args:
            td: Input TensorDict containing the environment state.
            embeddings: Precomputed embeddings for the nodes.
            action: Actions to evaluate, shape (batch_size, seq_len). It may be
                zero-padded beyond the number of decoding steps actually taken.
            env: Environment to use for decoding. If None, the environment is
                instantiated from ``self.env_name``; if a string, it is
                instantiated by that name. Passing an already instantiated
                environment is more efficient and gives finer-grained control.

        Returns:
            ll: Tensor of shape (batch_size, decoding_steps) with the
                log-likelihood of each evaluated action. Here
                ``decoding_steps`` is the number of steps taken until every
                instance is done (at most ``seq_len``).
            entropy: Tensor of shape (batch_size,) with the entropy of the
                decoder's action distribution, summed over decoding steps.

        Raises:
            IndexError: If ``action`` has fewer steps than are needed to finish
                decoding every instance in the batch.
        """
        # Instantiate environment if needed
        if isinstance(env, str) or env is None:
            env_name = self.env_name if env is None else env
            env = get_env(env_name)

        # Compute keys, values for the glimpse and keys for the logits once,
        # as they can be reused in every step
        cached_embeds = self._precompute_cache(embeddings)

        num_actions = action.size(-1)
        log_p = []
        decode_step = 0
        while not td["done"].all():
            # Fail with a clear message instead of an opaque indexing error
            # when the provided actions run out before decoding finishes.
            if decode_step >= num_actions:
                raise IndexError(
                    f"`action` provides only {num_actions} steps, but decoding "
                    f"has not finished for every instance at step {decode_step}"
                )
            step_log_p, _ = self._get_log_p(cached_embeds, td)
            td.set("action", action[..., decode_step])
            td = env.step(td)["next"]
            log_p.append(step_log_p)
            decode_step += 1

        # The number of decoding steps may be smaller than seq_len because the
        # actions can be zero-padded; only the steps actually taken are scored.
        log_p = torch.stack(log_p, 1)  # [batch_size, decoding steps, num_nodes]
        ll = get_log_likelihood(
            log_p, action[..., :decode_step], mask=None, return_sum=False
        )  # [batch_size, decoding steps]
        assert ll.isfinite().all(), "Log p is not finite"

        # Entropy of each step's distribution. Masked actions presumably carry
        # log_p = -inf (verify against _get_log_p). torch.nan_to_num also maps
        # -inf to the most negative finite value (its default `neginf`), so
        # exp(log_p) * log_p evaluates to 0 there rather than nan.
        log_p = torch.nan_to_num(log_p, nan=0.0)
        entropy = -(log_p.exp() * log_p).sum(dim=-1)  # [batch, decoder steps]
        entropy = entropy.sum(dim=1)  # [batch] -- sum over decoding steps
        assert entropy.isfinite().all(), "Entropy is not finite"
        return ll, entropy

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.