Source code for rl4co.models.zoo.pomo.model
from typing import Any, Union
import torch.nn as nn
from rl4co.data.transforms import StateAugmentation
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.reinforce.reinforce import REINFORCE
from rl4co.models.zoo.pomo.policy import POMOPolicy
from rl4co.utils.ops import gather_by_index, get_num_starts, unbatchify
class POMO(REINFORCE):
    """POMO Model for neural combinatorial optimization based on REINFORCE.

    Based on Kwon et al. (2020) http://arxiv.org/abs/2010.16011.

    Args:
        env: TorchRL Environment
        policy: Policy to use for the algorithm. If None, a ``POMOPolicy`` is
            built from ``env.name`` and ``policy_kwargs``.
        policy_kwargs: Keyword arguments for the policy. ``None`` (the default)
            is treated as an empty dict.
        baseline: Baseline to use for the algorithm. Note that POMO only supports
            the shared baseline, so an ``AssertionError`` is raised if anything
            else is passed.
        num_augment: Number of augmentations (used only for validation and test)
        use_dihedral_8: Whether to use dihedral 8 augmentation
        num_starts: Number of starts for multi-start. If None, the value is
            obtained from ``get_num_starts(td)`` at every step.
        **kwargs: Keyword arguments passed to the superclass
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: Union[nn.Module, POMOPolicy] = None,
        policy_kwargs: Union[dict, None] = None,
        baseline: str = "shared",
        num_augment: int = 8,
        use_dihedral_8: bool = True,
        num_starts: Union[int, None] = None,
        **kwargs,
    ):
        # Use a fresh dict per instance instead of one default dict object
        # shared by every POMO constructed without explicit policy kwargs.
        # NOTE(review): this is done before `save_hyperparameters` on the
        # assumption that it records the current argument values, so the stored
        # value stays a dict as before -- confirm.
        if policy_kwargs is None:
            policy_kwargs = {}

        self.save_hyperparameters(logger=False)

        if policy is None:
            policy = POMOPolicy(env.name, **policy_kwargs)

        assert baseline == "shared", "POMO only supports shared baseline"

        # Initialize with the shared baseline
        super(POMO, self).__init__(env, policy, baseline, **kwargs)

        self.num_starts = num_starts
        self.num_augment = num_augment
        self.augment = StateAugmentation(
            self.env.name, num_augment=self.num_augment, use_dihedral_8=use_dihedral_8
        )

        # Add `_multistart` to decode type for train, val and test in policy
        for phase in ["train", "val", "test"]:
            self.set_decode_type_multistart(phase)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        """Run one train/val/test step with multi-start decoding.

        Resets the environment on ``batch``, optionally augments the state
        (never during training), evaluates the policy and then either computes
        the loss (training) or collects the best multi-start / augmentation
        rewards and actions (validation and test).

        Args:
            batch: Batch of data handed to ``self.env.reset``.
            batch_idx: Index of the batch; not used in this method.
            phase: One of ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            Dict with key ``"loss"`` (``None`` when the policy output has no
            loss entry) merged with the metrics returned by ``self.log_metrics``.
        """
        td = self.env.reset(batch)
        n_aug, n_start = self.num_augment, self.num_starts
        n_start = get_num_starts(td) if n_start is None else n_start

        # During training, we do not augment the data
        if phase == "train":
            n_aug = 0
        elif n_aug > 1:
            td = self.augment(td)

        # Evaluate policy
        out = self.policy(td, self.env, phase=phase, num_starts=n_start)

        # Split the flat reward into separate multi-start / augmentation axes.
        # NOTE(review): the resulting axis order is determined by `unbatchify`,
        # and a size of 0 (training) is presumably treated as "no such axis"
        # -- confirm against `rl4co.utils.ops.unbatchify`.
        reward = unbatchify(out["reward"], (n_start, n_aug))

        # Training phase
        if phase == "train":
            assert n_start > 1, "num_starts must be > 1 during training"
            log_likelihood = unbatchify(out["log_likelihood"], (n_start, n_aug))
            # The return value is not used here; presumably the loss is written
            # into `out`, since it is read back via out.get("loss") below.
            self.calculate_loss(td, batch, out, reward, log_likelihood)

        # Get multi-start (=POMO) rewards and best actions only during validation and test
        else:
            if n_start > 1:
                # max multi-start reward
                max_reward, max_idxs = reward.max(dim=1)
                out.update({"max_reward": max_reward})

                # Reshape actions the same way as the reward so that `max_idxs`
                # can select the best start for every instance.
                if out.get("actions", None) is not None:
                    actions = unbatchify(out["actions"], (n_start, n_aug))
                    out.update(
                        {"best_multistart_actions": gather_by_index(actions, max_idxs)}
                    )
                    out["actions"] = actions

            # Get augmentation score only during inference
            if n_aug > 1:
                # If multistart is enabled, we use the best multistart rewards
                reward_ = max_reward if n_start > 1 else reward
                max_aug_reward, max_idxs = reward_.max(dim=1)
                out.update({"max_aug_reward": max_aug_reward})

                if out.get("best_multistart_actions", None) is not None:
                    out.update(
                        {
                            "best_aug_actions": gather_by_index(
                                out["best_multistart_actions"], max_idxs
                            )
                        }
                    )

        metrics = self.log_metrics(out, phase)
        return {"loss": out.get("loss", None), **metrics}