Shortcuts

Source code for rl4co.models.nn.utils

from typing import Union

import torch
import torch.nn as nn
from tensordict import TensorDict

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.utils import get_pylogger

log = get_pylogger(__name__)


def get_log_likelihood(log_p, actions, mask, return_sum: bool = True):
    """Get log likelihood of selected actions.

    Args:
        log_p: log-probabilities, gathered along dim 2 with ``actions``
            (presumably ``[batch, decode_len, num_actions]`` — TODO confirm).
        actions: chosen action indices, indexed along dim 2 of ``log_p``.
        mask: optional boolean tensor; positions where it is ``False`` are
            zeroed so they do not get reinforced. ``None`` keeps every step.
        return_sum: if True, sum the per-step log-likelihoods over dim 1.

    Returns:
        ``[batch]`` summed log-likelihood when ``return_sum`` is True,
        otherwise the per-step values ``[batch, decode_len]``.

    Raises:
        AssertionError: if any selected log-prob is <= -1000 (i.e. -inf),
            which indicates a problem in the sampling procedure.
    """
    index = actions.unsqueeze(-1)
    selected = torch.gather(log_p, 2, index).squeeze(-1)

    # Mask out actions that are irrelevant to the objective.
    if mask is not None:
        selected[~mask] = 0

    is_finite = selected > -1000
    assert torch.all(is_finite), "Logprobs should not be -inf, check sampling procedure!"

    return selected.sum(1) if return_sum else selected
def decode_probs(probs, mask, decode_type="sampling"):
    """Decode probabilities to select one action per row, respecting ``mask``.

    Args:
        probs: probability tensor; actions are chosen along dim 1
            (presumably ``[batch, num_actions]`` — TODO confirm).
        mask: boolean tensor aligned with ``probs`` where ``True`` marks an
            action that must NOT be selected. ``None`` skips feasibility checks.
        decode_type: any string containing ``"greedy"`` (argmax) or
            ``"sampling"`` (multinomial sampling).

    Returns:
        Tensor of selected action indices, one per row of ``probs``.

    Raises:
        ValueError: if ``decode_type`` matches neither ``"greedy"`` nor
            ``"sampling"``. (Previously an ``assert False``, which is stripped
            under ``python -O`` and then failed with an unbound ``selected``.)
        AssertionError: if ``probs`` contains NaNs, or greedy decoding picks
            a masked action.
    """
    # NaN is the only value not equal to itself.
    assert (probs == probs).all(), "Probs should not contain any nans"

    if "greedy" in decode_type:
        _, selected = probs.max(1)
        if mask is not None:
            assert not mask.gather(
                1, selected.unsqueeze(-1)
            ).any(), "Decode greedy: infeasible action has maximum probability"
    elif "sampling" in decode_type:
        selected = torch.multinomial(probs, 1).squeeze(1)
        if mask is not None:
            # Resample the whole batch until no row picked a masked action.
            while mask.gather(1, selected.unsqueeze(-1)).any():
                log.info("Sampled bad values, resampling!")
                selected = probs.multinomial(1).squeeze(1)
    else:
        raise ValueError("Unknown decode type: {}".format(decode_type))
    return selected
def random_policy(td):
    """Pick one random available action per instance and store it in ``td``.

    Available actions are the non-zero entries of ``td["action_mask"]``.
    Each entry is given equal weight (the mask is cast to float) and one
    index is drawn with ``torch.multinomial``.

    Args:
        td: state container holding ``"action_mask"``. It is updated in place
            via ``td.set("action", ...)``.

    Returns:
        The same ``td``, now carrying an ``"action"`` entry.
    """
    weights = td["action_mask"].float()
    drawn = torch.multinomial(weights, 1)
    td.set("action", drawn.squeeze(-1))
    return td
def rollout(env, td, policy, max_steps: int = None):
    """Helper function to roll out a policy until every instance is done.

    TorchRL's ``env.rollout()`` does not allow stepping an environment after
    some batch elements are done. That is a problem for environments whose
    instances finish at different steps, so this loop keeps stepping until
    ``td["done"]`` is all True.

    Args:
        env: environment exposing ``step(td)`` (its result is indexed with
            ``["next"]``) and ``get_reward(td, actions)``.
        td: initial state; must contain a ``"done"`` entry.
        policy: callable that takes ``td`` and returns it with an
            ``"action"`` entry.
        max_steps: maximum number of environment steps to take; ``None``
            means no limit. The check runs after each step, so at least one
            step is taken whenever ``td`` is not already done.

    Returns:
        Tuple ``(reward, td, actions)``: the reward from
        ``env.get_reward``, the final state, and the per-step actions stacked
        along dim 1.

    Note:
        If ``td`` is already fully done on entry, no action is collected and
        ``torch.stack`` raises.
    """
    limit = float("inf") if max_steps is None else max_steps
    actions = []
    steps = 0
    while not td["done"].all():
        td = policy(td)
        actions.append(td["action"])
        td = env.step(td)["next"]
        steps += 1
        # ``>=`` so that exactly ``max_steps`` steps are taken; the previous
        # ``>`` comparison allowed one extra step.
        if steps >= limit:
            break
    stacked_actions = torch.stack(actions, dim=1)
    return env.get_reward(td, stacked_actions), td, stacked_actions
class RandomPolicy(nn.Module):
    """Policy that takes a random available action at every step.

    Useful for sanity-checking an environment during development. ``forward``
    has the same signature as the ``AutoregressivePolicy`` class's
    ``forward``, so the two can be swapped.

    Args:
        env_name: environment name passed to ``get_env`` when ``forward`` is
            called without an instantiated environment.
    """

    def __init__(self, env_name=None):
        super().__init__()
        self.env_name = env_name

    def forward(
        self,
        td: TensorDict,
        env: Union[str, RL4COEnvBase] = None,
        max_steps: int = None,
    ):
        """Roll out ``random_policy`` in ``env`` starting from ``td``.

        Args:
            td: initial state.
            env: an environment instance, an environment name, or ``None``
                to fall back to ``self.env_name``. Names are resolved with
                ``get_env``.
            max_steps: forwarded unchanged to ``rollout``.

        Returns:
            Whatever ``rollout`` returns.
        """
        # Already have a concrete environment: roll out directly.
        if env is not None and not isinstance(env, str):
            return rollout(env, td, random_policy, max_steps=max_steps)

        env_name = env if env is not None else self.env_name
        log.info(f"Instantiated environment not provided; instantiating {env_name}")
        built_env = get_env(env_name)
        return rollout(built_env, td, random_policy, max_steps=max_steps)

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.