Reinforce
REINFORCE
- class rl4co.models.rl.reinforce.reinforce.REINFORCE(env, policy, baseline='rollout', baseline_kwargs={}, **kwargs)
Bases: RL4COLitModule
REINFORCE algorithm, also known as policy gradients. See superclass RL4COLitModule for more details.
- Parameters:
env (RL4COEnvBase) – Environment to use for the algorithm
policy (Module) – Policy to use for the algorithm
baseline (Union[REINFORCEBaseline, str]) – REINFORCE baseline
baseline_kwargs (dict) – Keyword arguments for baseline. Ignored if baseline is not a string
**kwargs – Keyword arguments passed to the superclass
- calculate_loss(td, batch, policy_out, reward=None, log_likelihood=None)
Calculate loss for REINFORCE algorithm.
- Parameters:
td (TensorDict) – TensorDict containing the current state of the environment
batch (TensorDict) – Batch of data. This is used to get the extra loss terms, e.g., REINFORCE baseline
policy_out (dict) – Output of the policy network
reward (Optional[Tensor]) – Reward tensor. If None, it is taken from policy_out
log_likelihood (Optional[Tensor]) – Log-likelihood tensor. If None, it is taken from policy_out
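As a rough illustration of what a REINFORCE loss with a baseline usually looks like, the sketch below works on plain Python floats rather than tensors. The function name `reinforce_loss` and its arguments are hypothetical and are not taken from the library; the actual `calculate_loss` may differ in details such as normalization or how the baseline term is obtained.

```python
def reinforce_loss(rewards, log_likelihoods, baseline_values, baseline_loss=0.0):
    """Sketch of a REINFORCE loss with a baseline (hypothetical helper)."""
    # Advantage: how much better each sampled solution is than the baseline.
    advantages = [r - b for r, b in zip(rewards, baseline_values)]
    # Surrogate objective: minimize -(advantage * log-likelihood), batch-averaged.
    policy_loss = -sum(a * ll for a, ll in zip(advantages, log_likelihoods)) / len(rewards)
    # Extra loss term contributed by the baseline (e.g. a critic's regression loss).
    return policy_loss + baseline_loss


rewards = [-10.0, -12.0]
log_likelihoods = [-1.0, -2.0]
baseline_values = [-11.0, -11.0]
loss = reinforce_loss(rewards, log_likelihoods, baseline_values)
```

Subtracting a baseline leaves the expected gradient unchanged but typically reduces its variance, which is why the baseline classes documented on this page exist.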
- classmethod load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=False, load_baseline=True, **kwargs)
Load model from checkpoint.
- Return type: Self
Note
This is a modified version of load_from_checkpoint from pytorch_lightning.core.saving. It deals with matching keys for the baseline by first running setup.
- post_setup_hook(stage='fit')
Hook to be called after setup. Can be used to set up subclasses without overriding setup.
- set_decode_type_multistart(phase)
Set decode type to multistart for train, val and test in policy. For example, if the decode type is greedy, it will be set to greedy_multistart.
- Parameters:
phase (str) – Phase to set decode type for. Must be one of train, val or test.
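The renaming described above can be sketched as follows. This is only an illustration: the `DummyPolicy` class and the `<phase>_decode_type` attribute names are assumptions made for the example, not something this page confirms about the policy object.

```python
def to_multistart(decode_type):
    """Hypothetical helper: append the multistart suffix unless already present."""
    if "multistart" in decode_type:
        return decode_type
    return f"{decode_type}_multistart"


class DummyPolicy:
    """Stand-in for a policy that stores one decode type per phase (assumed layout)."""
    train_decode_type = "sampling"
    val_decode_type = "greedy"
    test_decode_type = "greedy_multistart"


def set_decode_type_multistart(policy, phase):
    # Reject anything other than the three documented phases.
    if phase not in ("train", "val", "test"):
        raise ValueError(f"phase must be train, val or test, got {phase!r}")
    attribute = f"{phase}_decode_type"
    setattr(policy, attribute, to_multistart(getattr(policy, attribute)))


policy = DummyPolicy()
for phase in ("train", "val", "test"):
    set_decode_type_multistart(policy, phase)
```

Checking for an existing suffix keeps the operation idempotent, so calling it twice for the same phase does not produce a doubled suffix.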
Shared step between train/val/test. To be implemented in subclass
Baselines
- class rl4co.models.rl.reinforce.baselines.CriticBaseline(critic=None, **unused_kw)
Bases: REINFORCEBaseline
Critic baseline: use critic network as baseline
- Parameters:
critic (Module) – Critic network to use as baseline. If None, create a new critic network based on the environment
Initializes internal Module state, shared by both nn.Module and ScriptModule.
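To make the idea of a critic baseline concrete, here is a minimal sketch on plain floats. It assumes the common actor-critic convention that the critic's prediction is used as the baseline value and that the critic itself is trained with a mean-squared-error regression onto the observed reward; this page does not confirm that the library does exactly this, and `critic_baseline` is a hypothetical name.

```python
def critic_baseline(predicted_values, rewards):
    """Return (baseline values, critic regression loss) for one batch."""
    # The predictions act as the baseline; in a tensor framework they would be
    # treated as constants so no policy gradient flows through them.
    baseline_values = list(predicted_values)
    # Mean squared error between the critic's predictions and the observed rewards.
    critic_loss = sum((v - r) ** 2 for v, r in zip(predicted_values, rewards)) / len(rewards)
    return baseline_values, critic_loss


values, critic_loss = critic_baseline([-11.0, -11.0], [-10.0, -12.0])
```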
- class rl4co.models.rl.reinforce.baselines.ExponentialBaseline(beta=0.8, **kw)
Bases: REINFORCEBaseline
Exponential baseline: return exponential moving average of reward as baseline
- Parameters:
beta – Beta value for the exponential moving average
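An exponential moving average of the batch-mean reward can be sketched as below. The class name `ExponentialMovingAverage` is hypothetical, and initializing the average with the first batch mean is an assumption for this example rather than documented behavior.

```python
class ExponentialMovingAverage:
    """Sketch of an exponential moving average over batch-mean rewards."""

    def __init__(self, beta=0.8):
        self.beta = beta
        self.value = None  # no history before the first batch

    def update(self, rewards):
        batch_mean = sum(rewards) / len(rewards)
        if self.value is None:
            # Assumed initialization: start from the first batch mean.
            self.value = batch_mean
        else:
            # Keep a fraction beta of the old value, blend in the new mean.
            self.value = self.beta * self.value + (1.0 - self.beta) * batch_mean
        return self.value


ema = ExponentialMovingAverage(beta=0.8)
first = ema.update([-10.0, -12.0])
second = ema.update([-8.0, -10.0])
```

A larger beta makes the baseline smoother but slower to follow improvements in the policy.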
- class rl4co.models.rl.reinforce.baselines.MeanBaseline(**kw)
Bases: REINFORCEBaseline
Mean baseline: return mean of reward as baseline
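A mean baseline is simple enough to show in a few lines. The sketch below uses a hypothetical `mean_baseline` helper on plain floats and assumes the mean is taken over the whole batch.

```python
def mean_baseline(rewards):
    """Batch-mean reward, used as a single baseline value for every sample."""
    return sum(rewards) / len(rewards)


rewards = [-10.0, -12.0, -14.0]
baseline = mean_baseline(rewards)
# Advantages are centered: samples better than average get a positive weight.
advantages = [r - baseline for r in rewards]
```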
- class rl4co.models.rl.reinforce.baselines.NoBaseline(*args, **kw)
Bases: REINFORCEBaseline
No baseline: return 0 for baseline and neg_loss
- class rl4co.models.rl.reinforce.baselines.REINFORCEBaseline(*args, **kw)
Bases: Module
Base class for REINFORCE baselines
- epoch_callback(*args, **kw)
Callback at the end of each epoch. For example, update baseline parameters and obtain baseline values.
- class rl4co.models.rl.reinforce.baselines.RolloutBaseline(bl_alpha=0.05, **kw)
Bases: REINFORCEBaseline
Rollout baseline: use greedy rollout as baseline
- Parameters:
bl_alpha – Alpha value (significance level) for the baseline t-test
- epoch_callback(model, env, batch_size=64, device='cpu', epoch=None, dataset_size=None)
Challenges the current baseline with the model and replaces the baseline model if it is improved.
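The challenge step can be illustrated with a one-sided paired significance test on per-instance rewards. This is a sketch under stated assumptions: it uses only the standard library and approximates the t distribution with a normal distribution, which is reasonable only for large evaluation sets; the library may well use a proper paired t-test instead. The helper name `should_replace_baseline` is hypothetical.

```python
import math
from statistics import NormalDist, mean, stdev


def should_replace_baseline(candidate_rewards, baseline_rewards, alpha=0.05):
    """One-sided paired test: is the candidate's mean reward significantly higher?"""
    diffs = [c - b for c, b in zip(candidate_rewards, baseline_rewards)]
    mean_diff = mean(diffs)
    if mean_diff <= 0:
        return False  # candidate is not better on average
    spread = stdev(diffs)
    if spread == 0:
        return True  # every instance improved by the same positive amount
    # Paired test statistic on the per-instance differences.
    t_stat = mean_diff / (spread / math.sqrt(len(diffs)))
    # Normal approximation to the one-sided p-value (assumption, see lead-in).
    p_value = 1.0 - NormalDist().cdf(t_stat)
    return p_value < alpha


baseline_rewards = [-10.0, -12.0, -11.0, -9.0, -10.5, -11.5, -9.5, -10.0]
candidate_rewards = [-9.5, -11.4, -10.6, -8.7, -10.0, -11.1, -9.0, -9.6]
replace = should_replace_baseline(candidate_rewards, baseline_rewards)
```

Requiring statistical significance, rather than any improvement in the mean, keeps the baseline from being swapped on noise.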
- eval(td, reward, env)
Evaluate rollout baseline
Warning
This is not differentiable and should only be used for evaluation. Also, it is recommended to use the rollout method directly instead of this method.
- rollout(model, env, batch_size=64, device='cpu', dataset=None)
Rollout the model on the given dataset
- setup(*args, **kw)
To be called before training, during the setup phase. This follows PyTorch Lightning’s setup() convention.
- wrap_dataset(dataset, env, batch_size=64, device='cpu', **kw)
Wrap the dataset in a baseline dataset
Note
This is an alternative to eval that does not require the model to be passed at every call but just once. Values are added to the dataset. This also allows for larger batch sizes since we evaluate the model without gradients.
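The idea of computing baseline values once and storing them with the data can be sketched as follows. Everything here is illustrative: `wrap_with_baseline`, `toy_rollout`, and the `"data"` / `"extra"` keys are hypothetical names chosen for the example, not the library's dataset layout.

```python
def wrap_with_baseline(dataset, greedy_rollout):
    """Hypothetical helper: pair every instance with a precomputed baseline value."""
    return [{"data": instance, "extra": greedy_rollout(instance)} for instance in dataset]


def toy_rollout(instance):
    # Stand-in for a frozen greedy policy: reward is the negative sum of the instance.
    return -sum(instance)


dataset = [[1.0, 2.0], [3.0, 4.0]]
wrapped = wrap_with_baseline(dataset, toy_rollout)
```

During training, the stored value can then be read back from each item instead of rerunning the baseline model for every batch.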
Bases: REINFORCEBaseline
Shared baseline: return mean of reward as baseline
Evaluate baseline
- class rl4co.models.rl.reinforce.baselines.WarmupBaseline(baseline, n_epochs=1, warmup_exp_beta=0.8, **kw)
Bases: REINFORCEBaseline
Warmup baseline: return convex combination of baseline and exponential baseline
- Parameters:
- epoch_callback(*args, **kw)
Callback at the end of each epoch. For example, update baseline parameters and obtain baseline values.
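The convex combination used by WarmupBaseline can be sketched as below. The linear schedule for the mixing weight over `n_epochs` is an assumption made for illustration, and both helper names are hypothetical; only the idea of blending the wrapped baseline with an exponential baseline comes from the description above.

```python
def warmup_alpha(epoch, n_epochs=1):
    """Assumed schedule: weight on the wrapped baseline grows linearly, capped at 1."""
    return min(1.0, (epoch + 1) / float(n_epochs))


def warmup_value(wrapped_value, exponential_value, alpha):
    """Convex combination of the wrapped baseline and the exponential baseline."""
    return alpha * wrapped_value + (1.0 - alpha) * exponential_value


# After the first of four warmup epochs, a quarter of the weight is on the wrapped baseline.
alpha_early = warmup_alpha(epoch=0, n_epochs=4)
value_early = warmup_value(-10.0, -12.0, alpha_early)

# Once warmup is over, only the wrapped baseline contributes.
alpha_late = warmup_alpha(epoch=3, n_epochs=4)
value_late = warmup_value(-10.0, -12.0, alpha_late)
```

Blending in this way lets training start with a cheap exponential baseline while the wrapped baseline (for example a rollout baseline) has not yet become informative.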