Tasks: Train and Evaluate
Train
Evaluate
- class rl4co.tasks.eval.AugmentationEval(env, num_augment=8, force_dihedral_8=False, **kwargs)
Bases: EvalBase
Evaluates the policy via N state augmentations. force_dihedral_8 forces the use of 8 augmentations (rotations and flips), as in POMO: https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
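- Parameters:
num_augment (int) – Number of state augmentations (N)
force_dihedral_8 (bool) – Whether to force the use of the 8 dihedral augmentations (rotations and flips)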
- name = 'augmentation'
- property num_augment
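The dihedral group of order 8 referenced above can be illustrated on 2D node coordinates. The sketch below is not rl4co's implementation: it assumes coordinates are (x, y) pairs in the unit square [0, 1]^2 and writes the four rotations and four reflections of the square as plain functions (DIHEDRAL_8, dihedral_8_augment and pairwise_distances are hypothetical names). Because each map is an isometry, every augmented copy of a Euclidean routing instance has the same pairwise distances, so the tour lengths of corresponding solutions are unchanged.

```python
import math
from itertools import combinations
from typing import List, Tuple

Point = Tuple[float, float]

# The 8 symmetries of the unit square (dihedral group of order 8) as maps on
# a point (x, y) in [0, 1]^2: four rotations followed by four reflections.
DIHEDRAL_8 = [
    lambda x, y: (x, y),          # identity (rotation by 0 degrees)
    lambda x, y: (y, 1 - x),      # rotation by 90 degrees clockwise
    lambda x, y: (1 - x, 1 - y),  # rotation by 180 degrees
    lambda x, y: (1 - y, x),      # rotation by 90 degrees counter-clockwise
    lambda x, y: (1 - x, y),      # reflection across the vertical midline
    lambda x, y: (x, 1 - y),      # reflection across the horizontal midline
    lambda x, y: (y, x),          # reflection across the main diagonal
    lambda x, y: (1 - y, 1 - x),  # reflection across the anti-diagonal
]


def dihedral_8_augment(coords: List[Point]) -> List[List[Point]]:
    """Return the 8 transformed copies of an instance's node coordinates."""
    return [[f(x, y) for x, y in coords] for f in DIHEDRAL_8]


def pairwise_distances(coords: List[Point]) -> List[float]:
    """Euclidean distance of every unordered node pair, in a fixed order."""
    return [math.dist(a, b) for a, b in combinations(coords, 2)]


if __name__ == "__main__":
    instance = [(0.1, 0.2), (0.7, 0.4), (0.3, 0.9)]
    base = pairwise_distances(instance)
    for copy in dihedral_8_augment(instance):
        # Each symmetry is an isometry, so all pairwise distances are preserved
        assert all(math.isclose(d, e) for d, e in zip(base, pairwise_distances(copy)))
    print("all 8 copies preserve pairwise distances")
```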
- class rl4co.tasks.eval.EvalBase(env, progress=True, **kwargs)
Bases: object
Base class for evaluation.
- Parameters:
env – Environment to evaluate the policy on
progress (bool) – Whether to show evaluation progress
- __call__(policy, dataloader, **kwargs)
Evaluate the policy on the given dataloader with the given **kwargs. The self._inner method is implemented in subclasses and returns the actions and rewards.
- name = 'base'
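The split between __call__ and self._inner described for EvalBase is a template-method design: the base class owns the loop over the dataloader and the aggregation, and each subclass only supplies the per-batch decoding. The classes below are hypothetical plain-Python stand-ins (ToyEvalBase, ToyGreedyEval and the keys of the returned dictionary are made up for illustration); a real evaluator would run a neural policy on batched environment states rather than the toy callable used here.

```python
from typing import Any, Dict, Iterable, List, Tuple


class ToyEvalBase:
    """Hypothetical evaluation base class (illustration only, not rl4co's code).

    __call__ owns the loop over batches and the aggregation; subclasses only
    implement _inner, which returns (actions, rewards) for one batch.
    """

    name = "base"

    def __init__(self, env: Any = None, progress: bool = True, **kwargs: Any):
        self.env = env
        self.progress = progress  # unused in this toy version
        self.kwargs = kwargs

    def _inner(self, policy, batch, **kwargs) -> Tuple[List[Any], List[float]]:
        raise NotImplementedError("_inner is implemented in subclasses")

    def __call__(self, policy, dataloader: Iterable, **kwargs) -> Dict[str, Any]:
        all_actions: List[Any] = []
        all_rewards: List[float] = []
        for batch in dataloader:
            # Subclass-specific decoding for one batch
            actions, rewards = self._inner(policy, batch, **kwargs)
            all_actions.extend(actions)
            all_rewards.extend(rewards)
        return {
            "actions": all_actions,
            "rewards": all_rewards,
            "avg_reward": sum(all_rewards) / len(all_rewards),
        }


class ToyGreedyEval(ToyEvalBase):
    """Toy subclass: here a 'policy' maps one instance to (action, reward)."""

    name = "greedy"

    def _inner(self, policy, batch, **kwargs):
        outputs = [policy(instance) for instance in batch]
        actions = [action for action, _ in outputs]
        rewards = [reward for _, reward in outputs]
        return actions, rewards


if __name__ == "__main__":
    toy_policy = lambda x: (x * 2, -float(x))
    result = ToyGreedyEval()(toy_policy, [[1, 2], [3]])
    print(result["avg_reward"])  # -2.0
```

Keeping the loop in the base class means every subclass reports results in the same shape, and adding a new decoding strategy only requires a new _inner.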
- class rl4co.tasks.eval.GreedyEval(env, **kwargs)
Bases: EvalBase
Evaluates the policy using greedy decoding and a single trajectory.
- name = 'greedy'
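Greedy decoding builds a single trajectory by taking the highest-scoring feasible action at every step. As an illustration only, the sketch below replaces the learned policy with a hypothetical toy_logits function that scores each unvisited node by its negative distance from the current node and masks visited nodes with -inf; with that toy policy, greedy decoding reduces to the nearest-neighbour heuristic on a TSP-like instance. Using the negative tour length as the reward is a convention of the sketch.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def toy_logits(coords: Sequence[Point], current: int, visited: List[bool]) -> List[float]:
    """Hypothetical stand-in for a policy: closer nodes score higher, visited nodes are masked."""
    return [
        -math.inf if visited[j] else -math.dist(coords[current], coords[j])
        for j in range(len(coords))
    ]


def greedy_decode(coords: Sequence[Point], start: int = 0) -> List[int]:
    """Build a single trajectory by always taking the highest-scoring action."""
    n = len(coords)
    visited = [False] * n
    visited[start] = True
    tour = [start]
    for _ in range(n - 1):
        logits = toy_logits(coords, tour[-1], visited)
        nxt = max(range(n), key=lambda j: logits[j])  # argmax = greedy choice
        visited[nxt] = True
        tour.append(nxt)
    return tour


def tour_length(coords: Sequence[Point], tour: List[int]) -> float:
    """Length of the closed tour; the sketch uses reward = -tour_length."""
    return sum(math.dist(coords[a], coords[b]) for a, b in zip(tour, tour[1:] + tour[:1]))


if __name__ == "__main__":
    points = [(0.0, 0.0), (0.0, 3.0), (1.0, 0.0), (1.0, 2.0)]
    print(greedy_decode(points))  # [0, 2, 3, 1]
```

Because greedy decoding is deterministic, one rollout per instance is all it can produce; the other evaluators on this page spend extra rollouts for a chance at a better best reward.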
- class rl4co.tasks.eval.GreedyMultiStartAugmentEval(env, num_starts=None, num_augment=8, force_dihedral_8=False, **kwargs)
Bases: EvalBase
Evaluates the policy via num_starts samples from the policy and num_augment augmentations of each sample. force_dihedral_8 forces the use of 8 augmentations (rotations and flips), as in POMO: https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8
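- Parameters:
num_starts (int) – Number of greedy multistarts to use
num_augment (int) – Number of augmentations of each sample
force_dihedral_8 (bool) – Whether to force the use of the 8 dihedral augmentations (rotations and flips)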
- name = 'greedy_multistart_augment'
- property num_augment
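Combining multistarts with augmentations multiplies the number of greedy rollouts per instance: num_starts * num_augment, so for example 20 starts and 8 augmentations give 160 rollouts. A natural reduction, and the one assumed here, is to keep the best reward over that whole grid. The helper below sketches only that reduction; reward_fn is a hypothetical callback standing in for one greedy rollout on augmented copy a from start node s, and best_over_starts_and_augments is not part of rl4co.

```python
from typing import Callable, Tuple


def best_over_starts_and_augments(
    reward_fn: Callable[[int, int], float], num_augment: int, num_starts: int
) -> Tuple[float, int, int]:
    """Run every (augmentation, start) pair and keep the highest reward.

    reward_fn(a, s) is a hypothetical callback standing in for one greedy
    rollout on augmented copy `a` starting from node `s`.
    Returns (best_reward, best_augment_index, best_start_index).
    """
    best = (-float("inf"), -1, -1)
    for a in range(num_augment):
        for s in range(num_starts):
            r = reward_fn(a, s)
            if r > best[0]:  # strict '>' keeps the first pair on ties
                best = (r, a, s)
    return best


if __name__ == "__main__":
    # Made-up rewards for 3 augmentations x 2 starts (illustration only)
    table = [[-5.0, -4.0], [-3.5, -6.0], [-4.2, -3.9]]
    print(best_over_starts_and_augments(lambda a, s: table[a][s], 3, 2))  # (-3.5, 1, 0)
```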
- class rl4co.tasks.eval.GreedyMultiStartEval(env, num_starts=None, **kwargs)
Bases: EvalBase
Evaluates the policy via num_starts greedy multistart samples from the policy.
- Parameters:
num_starts (int) – Number of greedy multistarts to use
- name = 'greedy_multistart'
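Greedy multistart runs one greedy rollout per starting action and keeps the best. The sketch below assumes a TSP-like instance given as 2D points, uses a hypothetical nearest-neighbour toy policy in place of a trained one, and, as its own convention, treats num_starts=None as "start from every node"; it is not rl4co's code. The example points are collinear and chosen so that the rollout from node 0 zig-zags, while other starts find a shorter tour.

```python
import math
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float]


def greedy_from(coords: Sequence[Point], start: int) -> List[int]:
    """Greedy rollout of a hypothetical toy policy: always move to the nearest unvisited node."""
    tour, unvisited = [start], set(range(len(coords))) - {start}
    while unvisited:
        cur = tour[-1]
        # Ties broken by node index so the rollout is deterministic
        nxt = min(unvisited, key=lambda j: (math.dist(coords[cur], coords[j]), j))
        tour.append(nxt)
        unvisited.remove(nxt)
    return tour


def reward(coords: Sequence[Point], tour: List[int]) -> float:
    """Negative closed-tour length (higher is better)."""
    return -sum(math.dist(coords[a], coords[b]) for a, b in zip(tour, tour[1:] + tour[:1]))


def greedy_multistart(coords: Sequence[Point], num_starts: Optional[int] = None):
    """Run one greedy rollout per start node and keep the best one."""
    if num_starts is None:  # sketch convention: start from every node
        num_starts = len(coords)
    tours = [greedy_from(coords, s) for s in range(num_starts)]
    rewards = [reward(coords, t) for t in tours]
    best = max(range(num_starts), key=rewards.__getitem__)
    return tours[best], rewards[best]


if __name__ == "__main__":
    points = [(0.0, 0.0), (1.0, 0.0), (-1.5, 0.0), (3.6, 0.0)]
    print(reward(points, greedy_from(points, 0)))  # about -12.2
    print(greedy_multistart(points)[1])  # about -10.2
```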
- class rl4co.tasks.eval.SamplingEval(env, samples, softmax_temp=None, **kwargs)
Bases: EvalBase
Evaluates the policy via N samples from the policy.
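- Parameters:
samples (int) – Number of samples (N) to draw from the policy
softmax_temp (float) – Softmax temperature used when sampling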
- name = 'sampling'
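Sampling evaluation draws N stochastic rollouts and keeps the best one. In a temperature-scaled softmax the logits are divided by a temperature T before normalising, so T above 1 flattens the distribution and T close to 0 approaches greedy decoding. The sketch below illustrates this with a hypothetical toy policy that scores unvisited nodes by negative distance; sampling_eval, its seed argument, and reading softmax_temp as temp are assumptions for illustration, not the SamplingEval implementation.

```python
import math
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def softmax(logits: Sequence[float], temp: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; masked (-inf) logits get probability 0."""
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def tour_reward(coords: Sequence[Point], tour: List[int]) -> float:
    """Negative closed-tour length (higher is better)."""
    return -sum(math.dist(coords[a], coords[b]) for a, b in zip(tour, tour[1:] + tour[:1]))


def sample_tour(coords, rng: random.Random, temp: float = 1.0, start: int = 0) -> List[int]:
    """One stochastic rollout of a hypothetical distance-based toy policy."""
    n = len(coords)
    tour, visited = [start], {start}
    while len(tour) < n:
        cur = tour[-1]
        logits = [
            -math.inf if j in visited else -math.dist(coords[cur], coords[j])
            for j in range(n)
        ]
        nxt = rng.choices(range(n), weights=softmax(logits, temp))[0]
        tour.append(nxt)
        visited.add(nxt)
    return tour


def sampling_eval(coords, samples: int, temp: float = 1.0, seed: int = 0):
    """Draw `samples` rollouts and return the best tour and its reward."""
    rng = random.Random(seed)
    tours = [sample_tour(coords, rng, temp) for _ in range(samples)]
    rewards = [tour_reward(coords, t) for t in tours]
    best = max(range(samples), key=rewards.__getitem__)
    return tours[best], rewards[best]
```

With a very small temperature the masked softmax puts essentially all mass on the best-scoring node, so sampling_eval(points, samples=3, temp=1e-6) behaves like greedy decoding.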
- rl4co.tasks.eval.evaluate_policy(env, policy, dataset, method='greedy', batch_size=None, max_batch_size=4096, start_batch_size=8192, auto_batch_size=True, save_results=False, save_fname='results.npz', **kwargs)
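The evaluate_policy signature exposes start_batch_size, max_batch_size and auto_batch_size, but this page does not describe how they interact. One plausible strategy, and purely an assumption here, is to probe from start_batch_size downward, halving on out-of-memory, and then cap the result at max_batch_size. The sketch simulates out-of-memory with a hypothetical try_batch probe that raises MemoryError; on a GPU the failure would surface as a framework-specific error instead.

```python
from typing import Callable


def find_batch_size(
    try_batch: Callable[[int], None],
    start_batch_size: int = 8192,
    max_batch_size: int = 4096,
) -> int:
    """Halve the batch size until `try_batch` succeeds, then cap it.

    Assumed strategy for illustration. `try_batch(bs)` is a hypothetical probe
    that raises MemoryError when a batch of size `bs` does not fit.
    """
    bs = start_batch_size
    while bs >= 1:
        try:
            try_batch(bs)
            return min(bs, max_batch_size)
        except MemoryError:
            bs //= 2  # out of memory: retry with half the batch
    raise RuntimeError("even a batch of size 1 does not fit")


if __name__ == "__main__":
    def fits_up_to_3000(bs: int) -> None:
        if bs > 3000:
            raise MemoryError

    print(find_batch_size(fits_up_to_3000))  # 2048
```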