Reinforce

REINFORCE

class rl4co.models.rl.reinforce.reinforce.REINFORCE(env, policy, baseline='rollout', baseline_kwargs={}, **kwargs)[source]

Bases: RL4COLitModule

REINFORCE algorithm, also known as policy gradients. See superclass RL4COLitModule for more details.
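
For orientation, the policy-gradient estimator is usually written as below, with a baseline subtracted from the reward to reduce variance. The symbols (parameters $\theta$, policy $\pi_\theta$, problem instance $x$, solution $a$, reward $R$, baseline $b$) are notation assumed here for illustration; they are not taken from the library's docstrings.

```latex
\nabla_\theta J(\theta)
  = \mathbb{E}_{a \sim \pi_\theta(\cdot \mid x)}
    \left[ \bigl( R(a, x) - b(x) \bigr)\, \nabla_\theta \log \pi_\theta(a \mid x) \right]
```

The baseline classes documented in the Baselines section are different choices of $b$.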

Parameters:
  • env (RL4COEnvBase) – Environment to use for the algorithm

  • policy (Module) – Policy to use for the algorithm

  • baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline

  • baseline_kwargs (dict) – Keyword arguments for baseline. Ignored if baseline is not a string

  • **kwargs – Keyword arguments passed to the superclass

calculate_loss(td, batch, policy_out, reward=None, log_likelihood=None)[source]

Calculate loss for REINFORCE algorithm.

Parameters:
  • td (TensorDict) – TensorDict containing the current state of the environment

  • batch (TensorDict) – Batch of data. This is used to get the extra loss terms, e.g., REINFORCE baseline

  • policy_out (dict) – Output of the policy network

  • reward (Optional[Tensor]) – Reward tensor. If None, it is taken from policy_out

  • log_likelihood (Optional[Tensor]) – Log-likelihood tensor. If None, it is taken from policy_out
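
A minimal sketch of the kind of loss this method computes, written in plain Python on lists rather than tensors. The function name `reinforce_loss`, the `baseline_value` and `baseline_loss` arguments, and the exact way the terms are combined are assumptions made for illustration, not the library's implementation.

```python
from statistics import fmean


def reinforce_loss(reward, log_likelihood, baseline_value, baseline_loss=0.0):
    """Plain-Python sketch of a REINFORCE loss over one batch.

    reward, log_likelihood, baseline_value: equal-length lists of floats,
    one entry per instance. baseline_loss is an extra scalar term
    (for example, a critic's regression loss); it is an assumption here.
    """
    # Advantage: how much better the sampled solution was than the baseline.
    advantage = [r - b for r, b in zip(reward, baseline_value)]
    # Minimizing -(advantage * log-likelihood) raises the probability of
    # solutions whose reward beats the baseline.
    policy_term = -fmean(a * ll for a, ll in zip(advantage, log_likelihood))
    return policy_term + baseline_loss


loss = reinforce_loss(
    reward=[-10.0, -12.0],
    log_likelihood=[-2.0, -3.0],
    baseline_value=[-11.0, -11.0],
)
```

In a real training step the advantage would be treated as a constant, so that gradients flow only through the log-likelihood.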

classmethod load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=False, load_baseline=True, **kwargs)[source]

Load model from checkpoint.

Return type: Self

Note

This is a modified version of load_from_checkpoint from pytorch_lightning.core.saving. It deals with matching keys for the baseline by first running setup

on_train_epoch_end()[source]

Callback for end of training epoch: we evaluate the baseline

post_setup_hook(stage='fit')[source]

Hook to be called after setup. Can be used to set up subclasses without overriding setup

set_decode_type_multistart(phase)[source]

Set decode type to multistart for train, val and test in policy. For example, if the decode type is greedy, it will be set to greedy_multistart.

Parameters:

phase (str) – Phase to set decode type for. Must be one of train, val or test.
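
The renaming described above can be sketched as follows. The helper `set_multistart`, the stand-in `policy` object, and the `{phase}_decode_type` attribute naming are hypothetical and only illustrate the greedy to greedy_multistart mapping from the description.

```python
from types import SimpleNamespace


def set_multistart(policy, phase):
    """Hypothetical helper: switch one phase's decode type to multistart."""
    if phase not in ("train", "val", "test"):
        raise ValueError(f"phase must be train, val or test, got {phase!r}")
    attribute = f"{phase}_decode_type"  # assumed attribute naming
    current = getattr(policy, attribute)
    # Leave decode types that are already multistart untouched.
    if "multistart" not in current:
        setattr(policy, attribute, f"{current}_multistart")


# Stand-in for a policy object; a real policy would be a Module.
policy = SimpleNamespace(
    train_decode_type="sampling",
    val_decode_type="greedy",
    test_decode_type="greedy",
)
set_multistart(policy, "val")
```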

shared_step(batch, batch_idx, phase, dataloader_idx=None)[source]

Shared step between train/val/test. To be implemented in subclass

wrap_dataset(dataset)[source]

Wrap the dataset with values from baseline evaluation. Used in the greedy rollout baseline.


Baselines

class rl4co.models.rl.reinforce.baselines.CriticBaseline(critic=None, **unused_kw)[source]

Bases: REINFORCEBaseline

Critic baseline: use critic network as baseline

Parameters:

critic (Module) – Critic network to use as baseline. If None, create a new critic network based on the environment

Initializes internal Module state, shared by both nn.Module and ScriptModule.

eval(x, c, env=None)[source]

Evaluate baseline

setup(model, env, **kwargs)[source]

To be called before training, during the setup phase. This follows PyTorch Lightning’s setup() convention.

class rl4co.models.rl.reinforce.baselines.ExponentialBaseline(beta=0.8, **kw)[source]

Bases: REINFORCEBaseline

Exponential baseline: return exponential moving average of reward as baseline

Parameters:

beta – Beta value for the exponential moving average

Initializes internal Module state, shared by both nn.Module and ScriptModule.

eval(td, reward, env=None)[source]

Evaluate baseline
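
An exponential moving average of the batch reward can be sketched as below. This is a plain-Python illustration, not the library's code: the class name `ExponentialAverage`, the `update` method, and initializing the average from the first batch are assumptions.

```python
from statistics import fmean


class ExponentialAverage:
    """Sketch of an exponential-moving-average baseline (plain Python)."""

    def __init__(self, beta=0.8):
        self.beta = beta
        self.value = None  # no estimate until the first batch arrives

    def update(self, reward):
        batch_mean = fmean(reward)
        if self.value is None:
            # Assumption: the first batch initializes the average directly.
            self.value = batch_mean
        else:
            # Keep a fraction beta of the old estimate, blend in the new mean.
            self.value = self.beta * self.value + (1.0 - self.beta) * batch_mean
        return self.value


ema = ExponentialAverage(beta=0.8)
first = ema.update([-10.0, -12.0])  # batch mean is -11.0
second = ema.update([-6.0, -6.0])   # 0.8 * -11.0 + 0.2 * -6.0
```

A larger beta makes the baseline change more slowly from batch to batch.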

class rl4co.models.rl.reinforce.baselines.MeanBaseline(**kw)[source]

Bases: REINFORCEBaseline

Mean baseline: return mean of reward as baseline

Initializes internal Module state, shared by both nn.Module and ScriptModule.

class rl4co.models.rl.reinforce.baselines.NoBaseline(*args, **kw)[source]

Bases: REINFORCEBaseline

No baseline: return 0 for baseline and neg_loss

Initializes internal Module state, shared by both nn.Module and ScriptModule.

eval(td, reward, env=None)[source]

Evaluate baseline

class rl4co.models.rl.reinforce.baselines.REINFORCEBaseline(*args, **kw)[source]

Bases: Module

Base class for REINFORCE baselines

Initializes internal Module state, shared by both nn.Module and ScriptModule.

epoch_callback(*args, **kw)[source]

Callback at the end of each epoch. For example, update baseline parameters and obtain baseline values.

eval(td, reward, env=None)[source]

Evaluate baseline

setup(*args, **kw)[source]

To be called before training, during the setup phase. This follows PyTorch Lightning’s setup() convention.

wrap_dataset(dataset, *args, **kw)[source]

Wrap dataset with baseline-specific functionality

class rl4co.models.rl.reinforce.baselines.RolloutBaseline(bl_alpha=0.05, **kw)[source]

Bases: REINFORCEBaseline

Rollout baseline: use greedy rollout as baseline

Parameters:

bl_alpha – Alpha value for the baseline T-test

Initializes internal Module state, shared by both nn.Module and ScriptModule.

epoch_callback(model, env, batch_size=64, device='cpu', epoch=None, dataset_size=None)[source]

Challenges the current baseline with the model and replaces the baseline model if it is improved
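
The challenge step can be pictured as a paired significance test on per-instance rewards from the same evaluation dataset. The sketch below is an assumption-laden illustration: the function name `should_replace_baseline` is hypothetical, and the p-value uses a normal approximation in place of a proper t distribution to stay within the standard library.

```python
from math import sqrt
from statistics import NormalDist, fmean, stdev


def should_replace_baseline(candidate_reward, baseline_reward, alpha=0.05):
    """Decide whether a candidate model should replace the baseline model.

    Both arguments are per-instance rewards on the same evaluation dataset
    (higher is better). The p-value uses a normal approximation to the
    paired t-test, which is an illustrative simplification.
    """
    if fmean(candidate_reward) <= fmean(baseline_reward):
        return False  # no improvement on average, keep the baseline
    diffs = [c - b for c, b in zip(candidate_reward, baseline_reward)]
    std_err = stdev(diffs) / sqrt(len(diffs))
    if std_err == 0.0:
        return True  # the same improvement on every instance
    t_stat = fmean(diffs) / std_err
    p_value = 1.0 - NormalDist().cdf(t_stat)  # one-sided test
    return p_value < alpha


baseline = [-10.0, -10.4, -11.1, -10.1, -10.9]
clearly_better = [-9.0, -9.5, -10.0, -9.2, -9.8]
replace = should_replace_baseline(clearly_better, baseline, alpha=0.05)
```

Here `alpha` plays the role that `bl_alpha` is described as playing: a smaller value demands stronger evidence before the baseline model is replaced.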

eval(td, reward, env)[source]

Evaluate rollout baseline

Warning

This is not differentiable and should only be used for evaluation. Also, it is recommended to use the rollout method directly instead of this method.

rollout(model, env, batch_size=64, device='cpu', dataset=None)[source]

Rollout the model on the given dataset

setup(*args, **kw)[source]

To be called before training, during the setup phase. This follows PyTorch Lightning’s setup() convention.

wrap_dataset(dataset, env, batch_size=64, device='cpu', **kw)[source]

Wrap the dataset in a baseline dataset

Note

This is an alternative to eval that does not require the model to be passed at every call but just once. Values are added to the dataset. This also allows for larger batch sizes since we evaluate the model without gradients.

class rl4co.models.rl.reinforce.baselines.SharedBaseline(*args, **kw)[source]

Bases: REINFORCEBaseline

Shared baseline: return mean of reward as baseline

Initializes internal Module state, shared by both nn.Module and ScriptModule.

eval(td, reward, env=None, on_dim=1)[source]

Evaluate baseline
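
A sketch of a shared baseline on nested lists, assuming the reward is two-dimensional with shape [batch, num_starts] and that on_dim=1 indexes the rollouts belonging to one instance. The function name `shared_baseline` and the shape convention are assumptions for illustration.

```python
from statistics import fmean


def shared_baseline(reward):
    """Return, for every rollout, the mean reward of its own instance.

    reward: nested list shaped [batch, num_starts] (assumed layout).
    """
    return [[fmean(row)] * len(row) for row in reward]


reward = [[-10.0, -12.0, -11.0], [-5.0, -7.0, -6.0]]
baseline = shared_baseline(reward)
# Each rollout is compared against the average of its sibling rollouts.
advantage = [
    [r - b for r, b in zip(r_row, b_row)]
    for r_row, b_row in zip(reward, baseline)
]
```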

class rl4co.models.rl.reinforce.baselines.WarmupBaseline(baseline, n_epochs=1, warmup_exp_beta=0.8, **kw)[source]

Bases: REINFORCEBaseline

Warmup baseline: return convex combination of baseline and exponential baseline

Parameters:
  • baseline – Baseline to use after warmup

  • n_epochs – Number of epochs to warmup

  • warmup_exp_beta – Beta value for the exponential baseline during warmup

Initializes internal Module state, shared by both nn.Module and ScriptModule.

epoch_callback(*args, **kw)[source]

Callback at the end of each epoch. For example, update baseline parameters and obtain baseline values.

eval(td, reward, env=None)[source]

Evaluate baseline
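
The convex combination can be sketched as below. The function names and the linear weight schedule are assumptions for illustration; the description above only states that the result blends the wrapped baseline with an exponential baseline during warmup.

```python
def warmup_weight(epoch, n_epochs):
    """Weight on the wrapped baseline after a given (zero-based) epoch.

    Assumed schedule: grows linearly and reaches 1.0 once warmup is over.
    """
    return min(1.0, (epoch + 1) / n_epochs)


def blended_baseline(target_value, exponential_value, alpha):
    """Convex combination of the wrapped and the exponential baseline."""
    return alpha * target_value + (1.0 - alpha) * exponential_value


alpha = warmup_weight(epoch=0, n_epochs=4)  # a quarter of the way in
value = blended_baseline(target_value=-10.0, exponential_value=-12.0, alpha=alpha)
```

With a weight of 0 the result is the exponential baseline alone, and with a weight of 1 it is the wrapped baseline alone.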

setup(*args, **kw)[source]

To be called before training, during the setup phase. This follows PyTorch Lightning’s setup() convention.

wrap_dataset(dataset, *args, **kw)[source]

Wrap dataset with baseline-specific functionality

rl4co.models.rl.reinforce.baselines.get_reinforce_baseline(name, **kw)[source]

Get a REINFORCE baseline by name. The rollout baseline defaults to a warmup baseline that combines one epoch of exponential baseline with the greedy rollout baseline.