Shortcuts

Source code for rl4co.models.zoo.pomo.model

from typing import Any, Union

import torch.nn as nn

from rl4co.data.transforms import StateAugmentation
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.reinforce.reinforce import REINFORCE
from rl4co.models.zoo.pomo.policy import POMOPolicy
from rl4co.utils.ops import gather_by_index, get_num_starts, unbatchify


class POMO(REINFORCE):
    """POMO Model for neural combinatorial optimization based on REINFORCE.

    Based on Kwon et al. (2020) http://arxiv.org/abs/2010.16011.

    Args:
        env: TorchRL Environment
        policy: Policy to use for the algorithm. If None, a ``POMOPolicy`` is
            built for ``env.name`` from ``policy_kwargs``.
        policy_kwargs: Keyword arguments for the default policy (only used when
            ``policy`` is None). ``None`` means no extra arguments.
        baseline: Baseline to use for the algorithm. Note that POMO only
            supports the shared baseline, so a ``ValueError`` is raised if
            anything else is passed.
        num_augment: Number of augmentations (used only for validation and test)
        use_dihedral_8: Whether to use dihedral 8 augmentation
        num_starts: Number of starts for multi-start. If None, the value
            returned by ``get_num_starts`` on the reset state is used.
        **kwargs: Keyword arguments passed to the superclass
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: Union[nn.Module, POMOPolicy] = None,
        policy_kwargs: Union[dict, None] = None,
        baseline: str = "shared",
        num_augment: int = 8,
        use_dihedral_8: bool = True,
        num_starts: int = None,
        **kwargs,
    ):
        # Store the constructor arguments as hyperparameters (not sent to the logger).
        self.save_hyperparameters(logger=False)

        # Validate first so no default policy is built for an invalid config.
        # An explicit raise (not `assert`) keeps the check active under `python -O`.
        if baseline != "shared":
            raise ValueError("POMO only supports shared baseline")

        if policy is None:
            # `None` default instead of `{}` avoids a shared mutable default.
            policy = POMOPolicy(env.name, **(policy_kwargs or {}))

        # Initialize with the shared baseline
        super().__init__(env, policy, baseline, **kwargs)

        self.num_starts = num_starts
        self.num_augment = num_augment
        self.augment = StateAugmentation(
            self.env.name, num_augment=self.num_augment, use_dihedral_8=use_dihedral_8
        )

        # Add `_multistart` to decode type for train, val and test in policy
        for phase in ["train", "val", "test"]:
            self.set_decode_type_multistart(phase)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        """Run one multi-start step for the given phase.

        Training runs multi-start rollouts without augmentation and computes the
        loss. Validation and test also augment the data when
        ``num_augment > 1``. They then report the best multi-start and best
        augmentation rewards, plus the corresponding actions when the policy
        returns them.

        Args:
            batch: Batch of problem instances, passed to ``self.env.reset``.
            batch_idx: Index of the batch (not used by this method).
            phase: ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            Dict with ``"loss"`` (``out.get("loss")``, i.e. None when the policy
            output has no ``"loss"`` key — presumably filled by
            ``calculate_loss`` during training; confirm) merged with the metrics
            returned by ``self.log_metrics``.

        Raises:
            ValueError: If ``phase == "train"`` and the number of starts is not
                greater than 1.
        """
        td = self.env.reset(batch)
        n_aug = self.num_augment
        n_start = get_num_starts(td) if self.num_starts is None else self.num_starts

        if phase == "train":
            # Fail fast, before the policy rollout, instead of after it.
            # Explicit raise so the check survives `python -O`.
            if n_start <= 1:
                raise ValueError("num_starts must be > 1 during training")
            # During training, we do not augment the data
            n_aug = 0
        elif n_aug > 1:
            td = self.augment(td)

        # Evaluate policy
        out = self.policy(td, self.env, phase=phase, num_starts=n_start)

        # NOTE(review): the code below treats dim 1 of `reward` as the start
        # dimension ("max multi-start reward") and actions as
        # [batch, n_start, n_aug]; the exact layout is determined by
        # `unbatchify` — confirm against its implementation.
        reward = unbatchify(out["reward"], (n_start, n_aug))

        if phase == "train":
            log_likelihood = unbatchify(out["log_likelihood"], (n_start, n_aug))
            self.calculate_loss(td, batch, out, reward, log_likelihood)
        else:
            # Multi-start (=POMO) rewards and best actions: validation/test only.
            if n_start > 1:
                # Best reward over the starts.
                max_reward, best_start_idx = reward.max(dim=1)
                out.update({"max_reward": max_reward})

                if out.get("actions") is not None:
                    # Reshape actions to [batch, n_start, n_aug, ...]
                    actions = unbatchify(out["actions"], (n_start, n_aug))
                    out.update(
                        {
                            "best_multistart_actions": gather_by_index(
                                actions, best_start_idx
                            )
                        }
                    )
                    out["actions"] = actions

            # Augmentation score: inference only (n_aug is 0 during training).
            if n_aug > 1:
                # With multi-start enabled, pick among the best multi-start rewards.
                per_aug_reward = max_reward if n_start > 1 else reward
                max_aug_reward, best_aug_idx = per_aug_reward.max(dim=1)
                out.update({"max_aug_reward": max_aug_reward})

                # Derived from best_multistart_actions, so this only appears when
                # n_start > 1 and the policy returned actions.
                if out.get("best_multistart_actions") is not None:
                    out.update(
                        {
                            "best_aug_actions": gather_by_index(
                                out["best_multistart_actions"], best_aug_idx
                            )
                        }
                    )

        metrics = self.log_metrics(out, phase)
        return {"loss": out.get("loss", None), **metrics}

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.