Shortcuts

Source code for rl4co.models.zoo.pomo.model

from typing import Any, Union

import torch.nn as nn

from rl4co.data.transforms import StateAugmentation
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.reinforce.reinforce import REINFORCE
from rl4co.models.zoo.pomo.policy import POMOPolicy
from rl4co.utils.ops import (
    gather_by_index,
    get_num_starts,
    select_start_nodes,
    unbatchify,
)
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class POMO(REINFORCE):
    """POMO Model for neural combinatorial optimization based on REINFORCE.

    Based on Kwon et al. (2020) http://arxiv.org/abs/2010.16011.

    Args:
        env: TorchRL Environment
        policy: Policy to use for the algorithm. If None, a :class:`POMOPolicy` is
            built from ``env.name`` and ``policy_kwargs``
        policy_kwargs: Keyword arguments for policy. ``None`` is treated as an empty
            dict. The mapping is copied, so the caller's dict is never modified.
            If it contains ``select_start_nodes_fn``, that entry takes precedence
            over the ``select_start_nodes_fn`` argument below
        baseline: Baseline to use for the algorithm. Note that POMO only supports
            shared baseline, so we will throw an error if anything else is passed.
        num_augment: Number of augmentations (used only for validation and test)
        use_dihedral_8: Whether to use dihedral 8 augmentation
        num_starts: Number of starts for multi-start. If None, use the number of
            available actions
        select_start_nodes_fn: Function to select the start nodes for the environment
            defaulting to :func:`select_start_nodes`
        **kwargs: Keyword arguments passed to the superclass
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: Union[nn.Module, POMOPolicy] = None,
        policy_kwargs: dict = None,
        baseline: str = "shared",
        num_augment: int = 8,
        use_dihedral_8: bool = True,
        num_starts: int = None,
        select_start_nodes_fn: callable = select_start_nodes,
        **kwargs,
    ):
        self.save_hyperparameters(logger=False)

        # Work on a private copy: a shared `{}` default (or the caller's own dict)
        # must not be modified, otherwise the injected key leaks into later calls.
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)

        # If select_start_nodes_fn is provided in policy_kwargs, we use that instead;
        # otherwise the constructor argument is forwarded to the policy.
        if "select_start_nodes_fn" in policy_kwargs:
            log.info(
                "Overriding select_start_nodes_fn in POMO with the one provided in policy_kwargs"
            )
        else:
            policy_kwargs["select_start_nodes_fn"] = select_start_nodes_fn

        if policy is None:
            policy = POMOPolicy(env.name, **policy_kwargs)

        assert baseline == "shared", "POMO only supports shared baseline"

        # Initialize with the shared baseline
        super(POMO, self).__init__(env, policy, baseline, **kwargs)

        self.num_starts = num_starts
        self.num_augment = num_augment
        if self.num_augment > 1:
            self.augment = StateAugmentation(
                self.env.name,
                num_augment=self.num_augment,
                use_dihedral_8=use_dihedral_8,
            )
        else:
            self.augment = None

        # Add `_multistart` to decode type for train, val and test in policy
        for phase in ["train", "val", "test"]:
            self.set_decode_type_multistart(phase)

    def shared_step(
        self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
    ):
        """Run one train / val / test step with multi-start (and augmented) rollouts.

        Args:
            batch: Batch handed to ``self.env.reset``
            batch_idx: Index of the batch (not used by this method)
            phase: One of the phases; ``"train"`` disables augmentation and computes
                the loss, any other value collects the best multi-start and
                augmentation rewards (and actions, when the policy returns them)
            dataloader_idx: Forwarded to ``self.log_metrics``

        Returns:
            Dict with the key ``"loss"`` (``None`` when ``out`` holds no loss) merged
            with the metrics returned by ``self.log_metrics``.
        """
        td = self.env.reset(batch)
        n_aug, n_start = self.num_augment, self.num_starts
        n_start = get_num_starts(td, self.env.name) if n_start is None else n_start

        # During training, we do not augment the data
        if phase == "train":
            n_aug = 0
        elif n_aug > 1:
            td = self.augment(td)

        # Evaluate policy
        out = self.policy(td, self.env, phase=phase, num_starts=n_start)

        # Unbatchify reward to [batch_size, num_augment, num_starts].
        reward = unbatchify(out["reward"], (n_aug, n_start))

        # Training phase
        if phase == "train":
            assert n_start > 1, "num_starts must be > 1 during training"
            log_likelihood = unbatchify(out["log_likelihood"], (n_aug, n_start))
            # NOTE(review): the return value is ignored, so `calculate_loss` is
            # presumably expected to store "loss" in `out` in place -- confirm.
            self.calculate_loss(td, batch, out, reward, log_likelihood)

        # Get multi-start (=POMO) rewards and best actions only during validation and test
        else:
            if n_start > 1:
                # max multi-start reward
                max_reward, max_idxs = reward.max(dim=-1)
                out.update({"max_reward": max_reward})

                if out.get("actions", None) is not None:
                    # Reshape batch to [batch_size, num_augment, num_starts, ...]
                    actions = unbatchify(out["actions"], (n_aug, n_start))
                    out.update(
                        {"best_multistart_actions": gather_by_index(actions, max_idxs)}
                    )
                    out["actions"] = actions

            # Get augmentation score only during inference
            if n_aug > 1:
                # If multistart is enabled, we use the best multistart rewards
                reward_ = max_reward if n_start > 1 else reward
                max_aug_reward, max_idxs = reward_.max(dim=1)
                out.update({"max_aug_reward": max_aug_reward})

                if out.get("actions", None) is not None:
                    actions_ = (
                        out["best_multistart_actions"] if n_start > 1 else out["actions"]
                    )
                    out.update({"best_aug_actions": gather_by_index(actions_, max_idxs)})

        metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
        return {"loss": out.get("loss", None), **metrics}

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.