# Source code for rl4co.models.zoo.symnco.model
from typing import Any, Union
import torch.nn as nn
from rl4co.data.transforms import StateAugmentation
from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.reinforce.reinforce import REINFORCE
from rl4co.models.zoo.symnco.losses import (
invariance_loss,
problem_symmetricity_loss,
solution_symmetricity_loss,
)
from rl4co.models.zoo.symnco.policy import SymNCOPolicy
from rl4co.utils.ops import gather_by_index, get_num_starts, unbatchify
from rl4co.utils.pylogger import get_pylogger
log = get_pylogger(__name__)
class SymNCO(REINFORCE):
    """SymNCO Model based on REINFORCE with shared baselines.

    Based on Kim et al. (2022) https://arxiv.org/abs/2205.13209.

    Args:
        env: TorchRL environment to use for the algorithm
        policy: Policy to use for the algorithm. If None, a ``SymNCOPolicy`` is
            created from ``env.name`` and ``policy_kwargs``
        policy_kwargs: Keyword arguments for policy; only used when ``policy``
            is None. None is treated as an empty dict
        baseline: Must be ``"symnco"``. The superclass is always given ``"no"``
            since there are multiple custom baselines
        num_augment: Number of augmentations
        alpha: weight for invariance loss
        beta: weight for solution symmetricity loss
        num_starts: Number of starts for multi-start. If None, use the number
            of available actions. Values of 1 or less leave the decode types
            untouched and skip the problem symmetricity loss
        **kwargs: Keyword arguments passed to the superclass
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: Union[nn.Module, SymNCOPolicy, None] = None,
        policy_kwargs: Union[dict, None] = None,
        baseline: str = "symnco",
        num_augment: int = 4,
        alpha: float = 0.2,
        beta: float = 1,
        num_starts: Union[int, None] = 0,
        **kwargs,
    ):
        # Use a fresh dict per instance instead of a shared mutable default.
        # NOTE(review): normalised before `save_hyperparameters` so the recorded
        # value is presumably still a dict, as before -- confirm against how
        # Lightning collects the init arguments.
        if policy_kwargs is None:
            policy_kwargs = {}

        self.save_hyperparameters(logger=False)

        if policy is None:
            policy = SymNCOPolicy(env.name, **policy_kwargs)

        assert baseline == "symnco", "SymNCO only supports custom-symnco baseline"
        # Pass no baseline to superclass since there are multiple custom baselines
        super().__init__(env, policy, "no", **kwargs)

        self.num_starts = num_starts
        self.num_augment = num_augment
        self.augment = StateAugmentation(self.env.name, num_augment=self.num_augment)
        self.alpha = alpha  # weight for invariance loss
        self.beta = beta  # weight for solution symmetricity loss

        # Add `_multistart` to decode type for train, val and test in policy if
        # multi-start is in use. `num_starts=None` means the number of starts is
        # resolved per batch in `shared_step`, so it is handled as multi-start
        # here rather than being compared with an int (which would raise).
        if self.num_starts is None or self.num_starts > 1:
            for phase in ["train", "val", "test"]:
                self.set_decode_type_multistart(phase)

    def shared_step(
        self,
        batch: Any,
        batch_idx: int,
        phase: str,
        dataloader_idx: Union[int, None] = None,
    ):
        """Run one train/val/test step.

        Resets the environment on ``batch``, optionally augments the state,
        evaluates the policy, then either computes the SymNCO training losses
        (``phase == "train"``) or the best multi-start / augmentation rewards
        and actions (any other phase).

        Args:
            batch: Batch passed to ``self.env.reset``
            batch_idx: Index of the batch (not used in this method)
            phase: Current phase; ``"train"`` selects the loss branch
            dataloader_idx: Forwarded to ``self.log_metrics``

        Returns:
            Dict with key ``"loss"`` (None outside of training) merged with the
            metrics returned by ``self.log_metrics``.
        """
        td = self.env.reset(batch)
        n_aug = self.num_augment
        n_start = self.num_starts
        if n_start is None:
            n_start = get_num_starts(td, self.env.name)

        # Symmetric augmentation
        if n_aug > 1:
            td = self.augment(td)

        # Evaluate policy
        out = self.policy(td, self.env, phase=phase, num_starts=n_start)

        # Unbatchify reward to [batch_size, n_start, n_aug].
        reward = unbatchify(out["reward"], (n_start, n_aug))

        if phase == "train":
            # Main training loss
            # [batch_size, n_start, n_aug]
            ll = unbatchify(out["log_likelihood"], (n_start, n_aug))

            # Calculate losses: problem symmetricity, solution symmetricity,
            # invariance. A term whose precondition is not met contributes 0.
            loss_ps = problem_symmetricity_loss(reward, ll) if n_start > 1 else 0
            loss_ss = solution_symmetricity_loss(reward, ll) if n_aug > 1 else 0
            loss_inv = (
                invariance_loss(out["proj_embeddings"], n_aug) if n_aug > 1 else 0
            )
            loss = loss_ps + self.beta * loss_ss + self.alpha * loss_inv
            out.update(
                {
                    "loss": loss,
                    "loss_ss": loss_ss,
                    "loss_ps": loss_ps,
                    "loss_inv": loss_inv,
                }
            )
        else:
            # Log only during validation and test
            if n_start > 1:
                # max multi-start reward
                max_reward, start_idxs = reward.max(dim=1)
                out.update({"max_reward": max_reward})

                # Reshape batch to [batch, n_start, n_aug]
                if out.get("actions", None) is not None:
                    actions = unbatchify(out["actions"], (n_start, n_aug))
                    out.update(
                        {
                            "best_multistart_actions": gather_by_index(
                                actions, start_idxs
                            )
                        }
                    )
                    out["actions"] = actions

            # Get augmentation score only during inference
            if n_aug > 1:
                # If multistart is enabled, we use the best multistart rewards
                reward_ = max_reward if n_start > 1 else reward
                max_aug_reward, aug_idxs = reward_.max(dim=1)
                out.update({"max_aug_reward": max_aug_reward})

                if out.get("best_multistart_actions", None) is not None:
                    out.update(
                        {
                            "best_aug_actions": gather_by_index(
                                out["best_multistart_actions"], aug_idxs
                            )
                        }
                    )

        metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
        return {"loss": out.get("loss", None), **metrics}