Shortcuts

Source code for rl4co.models.rl.reinforce.reinforce

from typing import IO, Any, Optional, Union, cast

import torch
import torch.nn as nn

from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.core.saving import _load_from_checkpoint
from tensordict import TensorDict
from typing_extensions import Self

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline
from rl4co.utils.lightning import get_lightning_device
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


[docs]class REINFORCE(RL4COLitModule): """REINFORCE algorithm, also known as policy gradients. See superclass `RL4COLitModule` for more details. Args: env: Environment to use for the algorithm policy: Policy to use for the algorithm baseline: REINFORCE baseline baseline_kwargs: Keyword arguments for baseline. Ignored if baseline is not a string **kwargs: Keyword arguments passed to the superclass """ def __init__( self, env: RL4COEnvBase, policy: nn.Module, baseline: Union[REINFORCEBaseline, str] = "rollout", baseline_kwargs: dict = {}, **kwargs, ): super().__init__(env, policy, **kwargs) self.save_hyperparameters(logger=False) if isinstance(baseline, str): baseline = get_reinforce_baseline(baseline, **baseline_kwargs) else: if baseline_kwargs != {}: log.warning("baseline_kwargs is ignored when baseline is not a string") self.baseline = baseline
[docs] def shared_step(self, batch: Any, batch_idx: int, phase: str): td = self.env.reset(batch) # Perform forward pass (i.e., constructing solution and computing log-likelihoods) out = self.policy(td, self.env, phase=phase) # Compute loss if phase == "train": out = self.calculate_loss(td, batch, out) metrics = self.log_metrics(out, phase) return {"loss": out.get("loss", None), **metrics}
[docs] def calculate_loss( self, td: TensorDict, batch: TensorDict, policy_out: dict, reward: Optional[torch.Tensor] = None, log_likelihood: Optional[torch.Tensor] = None, ): """Calculate loss for REINFORCE algorithm. Args: td: TensorDict containing the current state of the environment batch: Batch of data. This is used to get the extra loss terms, e.g., REINFORCE baseline policy_out: Output of the policy network reward: Reward tensor. If None, it is taken from `policy_out` log_likelihood: Log-likelihood tensor. If None, it is taken from `policy_out` """ # Extra: this is used for additional loss terms, e.g., REINFORCE baseline extra = batch.get("extra", None) reward = reward if reward is not None else policy_out["reward"] log_likelihood = ( log_likelihood if log_likelihood is not None else policy_out["log_likelihood"] ) # REINFORCE baseline bl_val, bl_loss = ( self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) ) # Main loss function advantage = reward - bl_val # advantage = reward - baseline reinforce_loss = -(advantage * log_likelihood).mean() loss = reinforce_loss + bl_loss policy_out.update( { "loss": loss, "reinforce_loss": reinforce_loss, "bl_loss": bl_loss, "bl_val": bl_val, } ) return policy_out
[docs] def post_setup_hook(self, stage="fit"): # Make baseline taking model itself and train_dataloader from model as input self.baseline.setup( self.policy, self.env, batch_size=self.val_batch_size, device=get_lightning_device(self), dataset_size=self.data_cfg["val_data_size"], )
[docs] def on_train_epoch_end(self): """Callback for end of training epoch: we evaluate the baseline""" self.baseline.epoch_callback( self.policy, env=self.env, batch_size=self.val_batch_size, device=get_lightning_device(self), epoch=self.current_epoch, dataset_size=self.data_cfg["val_data_size"], ) # Need to call super() for the dataset to be reset super().on_train_epoch_end()
[docs] def wrap_dataset(self, dataset): """Wrap dataset from baseline evaluation. Used in greedy rollout baseline""" return self.baseline.wrap_dataset( dataset, self.env, batch_size=self.val_batch_size, device=get_lightning_device(self), )
[docs] def set_decode_type_multistart(self, phase: str): """Set decode type to `multistart` for train, val and test in policy. For example, if the decode type is `greedy`, it will be set to `greedy_multistart`. Args: phase: Phase to set decode type for. Must be one of `train`, `val` or `test`. """ attribute = f"{phase}_decode_type" attr_get = getattr(self.policy, attribute) # If does not exist, log error if attr_get is None: log.error(f"Decode type for {phase} is None. Cannot add `_multistart`.") return elif "multistart" in attr_get: return else: setattr(self.policy, attribute, f"{attr_get}_multistart")
[docs] @classmethod def load_from_checkpoint( cls, checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, strict: bool = False, load_baseline: bool = True, **kwargs: Any, ) -> Self: """Load model from checkpoint/ Note: This is a modified version of `load_from_checkpoint` from `pytorch_lightning.core.saving`. It deals with matching keys for the baseline by first running setup """ if strict: log.warning("Setting strict=False for loading model from checkpoint.") strict = False # Do not use strict loaded = _load_from_checkpoint( cls, checkpoint_path, map_location, hparams_file, strict, **kwargs, ) # Load baseline state dict if load_baseline: # setup baseline first loaded.setup() loaded.post_setup_hook() # load baseline state dict state_dict = torch.load(checkpoint_path)["state_dict"] # get only baseline parameters state_dict = {k: v for k, v in state_dict.items() if "baseline" in k} state_dict = {k.replace("baseline.", "", 1): v for k, v in state_dict.items()} loaded.baseline.load_state_dict(state_dict) return cast(Self, loaded)

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.