Tasks: Train and Evaluate

Train

rl4co.tasks.train.train(cfg)[source]
Return type:

Optional[float]

Evaluate

class rl4co.tasks.eval.AugmentationEval(env, num_augment=8, force_dihedral_8=False, **kwargs)[source]

Bases: EvalBase

Evaluates the policy via num_augment state augmentations. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO; see https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8

Parameters:
  • num_augment (int) – Number of state augmentations

  • force_dihedral_8 (bool) – Whether to force the use of 8 augmentations

name = 'augmentation'
property num_augment
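
To make the 8 rotations and flips concrete, here is a minimal sketch of a dihedral-8 augmentation for 2D coordinates in the unit square. It is an illustration in plain Python, not the library's implementation; the helper names dihedral_8 and tour_length are hypothetical. Each transform is a symmetry of the square, so the length of a tour is the same in every augmented copy, even though the policy sees different inputs.

```python
import math

def dihedral_8(points):
    """Return 8 symmetric copies of a list of (x, y) points in [0, 1]^2.

    Hypothetical helper: the 8 maps are the symmetries of the unit square
    (identity, flips, and the same after swapping x and y).
    """
    transforms = [
        lambda x, y: (x, y),
        lambda x, y: (1 - x, y),
        lambda x, y: (x, 1 - y),
        lambda x, y: (1 - x, 1 - y),
        lambda x, y: (y, x),
        lambda x, y: (1 - y, x),
        lambda x, y: (y, 1 - x),
        lambda x, y: (1 - y, 1 - x),
    ]
    return [[t(x, y) for x, y in points] for t in transforms]

def tour_length(points, order):
    """Length of the closed tour visiting points in the given order."""
    total = 0.0
    for i in range(len(order)):
        x1, y1 = points[order[i]]
        x2, y2 = points[order[(i + 1) % len(order)]]
        total += math.hypot(x2 - x1, y2 - y1)
    return total

points = [(0.1, 0.2), (0.7, 0.3), (0.4, 0.9)]
augmented = dihedral_8(points)
lengths = [tour_length(p, [0, 1, 2]) for p in augmented]
```

Because the tour length is unchanged by these transforms, any difference in the rewards obtained across the 8 copies comes from the policy decoding differently on each view.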
class rl4co.tasks.eval.EvalBase(env, progress=True, **kwargs)[source]

Bases: object

Base class for evaluation

Parameters:
  • env – Environment

  • progress – Whether to show progress bar

  • **kwargs – Additional arguments (to be implemented in subclasses)

__call__(policy, dataloader, **kwargs)[source]

Evaluates the policy on the given dataloader with the given **kwargs. self._inner is implemented in subclasses and returns actions and rewards.

name = 'base'
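
The split between __call__ and self._inner can be sketched with toy classes. This is a simplified illustration of the pattern described above, not the library's code: ToyEvalBase, ToyGreedyEval, toy_policy, and the returned dictionary keys are all hypothetical, and a real evaluator works on tensors rather than Python lists.

```python
class ToyEvalBase:
    """Toy stand-in for a base evaluator: loops over batches, delegates to _inner."""

    name = "base"

    def __call__(self, policy, dataloader, **kwargs):
        actions, rewards = [], []
        for batch in dataloader:
            # Subclasses decide how to decode; they return actions and rewards
            batch_actions, batch_rewards = self._inner(policy, batch, **kwargs)
            actions.extend(batch_actions)
            rewards.extend(batch_rewards)
        return {
            "actions": actions,
            "rewards": rewards,
            "avg_reward": sum(rewards) / len(rewards),
        }

    def _inner(self, policy, batch, **kwargs):
        raise NotImplementedError("Implemented in subclasses")


class ToyGreedyEval(ToyEvalBase):
    """Toy greedy evaluator: one greedy rollout per instance."""

    name = "greedy"

    def _inner(self, policy, batch, **kwargs):
        return policy(batch, decode_type="greedy")


def toy_policy(batch, decode_type="greedy"):
    # Pretend each instance is a number whose cost is its own value
    actions = [0 for _ in batch]
    rewards = [-float(x) for x in batch]
    return actions, rewards


result = ToyGreedyEval()(toy_policy, [[1, 2], [3]])
```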
class rl4co.tasks.eval.GreedyEval(env, **kwargs)[source]

Bases: EvalBase

Evaluates the policy using greedy decoding and single trajectory

name = 'greedy'
class rl4co.tasks.eval.GreedyMultiStartAugmentEval(env, num_starts=None, num_augment=8, force_dihedral_8=False, **kwargs)[source]

Bases: EvalBase

Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample. force_dihedral_8 forces the use of 8 augmentations (rotations and flips) as in POMO; see https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8

Parameters:
  • num_starts – Number of greedy multistart samples

  • num_augment – Number of augmentations per sample

  • force_dihedral_8 – If True, force the use of 8 augmentations (rotations and flips) as in POMO

name = 'greedy_multistart_augment'
property num_augment
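
As a rough sketch of how multistart and augmentation combine, the snippet below assumes that the best reward over all num_augment x num_starts rollouts of an instance is kept, as in POMO. That reduction is an assumption about the evaluator rather than something stated above, and best_reward is a hypothetical helper.

```python
def best_reward(rewards):
    """rewards[a][s] is the reward of start s on augmentation a of one instance."""
    return max(max(per_start) for per_start in rewards)

# One instance with num_augment=2 and num_starts=3
# (rewards are negative tour lengths, so higher is better)
rewards = [
    [-5.2, -4.8, -5.0],
    [-4.9, -4.6, -5.1],
]
best = best_reward(rewards)
```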
class rl4co.tasks.eval.GreedyMultiStartEval(env, num_starts=None, **kwargs)[source]

Bases: EvalBase

Evaluates the policy via num_starts greedy multistart samples from the policy

Parameters:
  • num_starts (int) – Number of greedy multistarts to use

name = 'greedy_multistart'
class rl4co.tasks.eval.SamplingEval(env, samples, softmax_temp=None, **kwargs)[source]

Bases: EvalBase

Evaluates the policy via N samples from the policy

Parameters:
  • samples (int) – Number of samples to take

  • softmax_temp (float) – Temperature for softmax sampling. The higher the temperature, the more random the sampling

name = 'sampling'
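
The effect of softmax_temp can be illustrated with a small standalone softmax. This is a generic sketch of temperature scaling, not the library's sampling code; softmax_with_temperature is a hypothetical helper and the logits are made up for the example.

```python
import math

def softmax_with_temperature(logits, temp=1.0):
    """Softmax over logits divided by temp (temp > 0)."""
    scaled = [l / temp for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, temp=0.5)  # concentrates on the top action
flat = softmax_with_temperature(logits, temp=5.0)   # closer to uniform
```

With a low temperature most of the probability mass sits on the highest-scoring action, so samples resemble greedy decoding; with a high temperature the distribution flattens and the samples become more diverse.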
rl4co.tasks.eval.check_unused_kwargs(class_, kwargs)[source]
rl4co.tasks.eval.evaluate_policy(env, policy, dataset, method='greedy', batch_size=None, max_batch_size=4096, start_batch_size=8192, auto_batch_size=True, save_results=False, save_fname='results.npz', **kwargs)[source]
rl4co.tasks.eval.get_automatic_batch_size(eval_fn, start_batch_size=8192, max_batch_size=4096)[source]

Automatically reduces the batch size based on the eval function

Parameters:
  • eval_fn – The eval function

  • start_batch_size – The starting batch size. This should be the theoretical maximum batch size

  • max_batch_size – The maximum batch size. This is the practical maximum batch size
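
One plausible reading of the two parameters is sketched below: the theoretical maximum is divided by the number of rollouts the eval function runs per instance, and the result is capped at the practical maximum. This is an assumption for illustration only, not the library's implementation; sketch_batch_size is a hypothetical helper, and passing num_augment and samples directly stands in for whatever the real function reads from eval_fn.

```python
def sketch_batch_size(start_batch_size=8192, max_batch_size=4096,
                      num_augment=1, samples=1):
    """Hypothetical heuristic: shrink the batch by the per-instance rollout count."""
    # Each instance is expanded into num_augment * samples rollouts
    batch_size = start_batch_size // (num_augment * samples)
    # Never exceed the practical maximum, never go below one instance
    return max(1, min(batch_size, max_batch_size))

plain = sketch_batch_size()                            # capped by max_batch_size
augmented = sketch_batch_size(num_augment=8)           # 8192 // 8
heavy = sketch_batch_size(num_augment=8, samples=1280) # floors at 1
```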