Shortcuts

Source code for rl4co.tasks.eval

import numpy as np
import torch

from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from rl4co.data.dataset import tensordict_collate_fn
from rl4co.data.transforms import StateAugmentation
from rl4co.utils.ops import batchify, gather_by_index, unbatchify


def check_unused_kwargs(class_, kwargs):
    """Print a warning when an evaluator received kwargs it does not consume.

    Nothing is printed when ``kwargs`` is empty, or when its only entry is
    ``"progress"``.

    Args:
        class_: Instance whose class name is shown in the warning
        kwargs: Dictionary of leftover keyword arguments
    """
    if len(kwargs) == 0:
        return
    only_progress = len(kwargs) == 1 and "progress" in kwargs
    if only_progress:
        return
    print(f"Warning: {class_.__class__.__name__} does not use kwargs {kwargs}")
class EvalBase:
    """Base class for evaluation

    Args:
        env: Environment
        progress: Whether to show progress bar
        **kwargs: Additional arguments (to be implemented in subclasses)
    """

    name = "base"

    def __init__(self, env, progress=True, **kwargs):
        check_unused_kwargs(self, kwargs)
        self.env = env
        self.progress = progress

    def __call__(self, policy, dataloader, **kwargs):
        """Evaluate the policy on the given dataloader with **kwargs parameter

        ``self._inner`` is implemented in subclasses and returns actions and
        rewards for one batch.

        Args:
            policy: Policy to evaluate; batches are moved to the device of its
                first parameter
            dataloader: Iterable of batches
            **kwargs: Forwarded to ``self._inner``

        Returns:
            dict with keys ``"actions"`` (all batches, zero-padded on the last
            dimension and concatenated, on CPU), ``"rewards"`` (on CPU),
            ``"inference_time"`` (milliseconds) and ``"avg_reward"``.
        """
        import time

        # CUDA events give accurate GPU timings; without CUDA they cannot be
        # recorded, so fall back to a wall-clock timer.
        use_cuda_timer = torch.cuda.is_available()
        if use_cuda_timer:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        else:
            start_time = time.perf_counter()

        with torch.no_grad():
            rewards_list = []
            actions_list = []
            device = next(policy.parameters()).device

            for batch in tqdm(
                dataloader, disable=not self.progress, desc=f"Running {self.name}"
            ):
                td = batch.to(device)
                td = self.env.reset(td)
                actions, rewards = self._inner(policy, td, **kwargs)
                rewards_list.append(rewards)
                actions_list.append(actions)

            rewards = torch.cat(rewards_list)
            # Pad over *all* collected batches (not just the last one)
            actions = self._pad_and_cat_actions(actions_list)

        if use_cuda_timer:
            end_event.record()
            torch.cuda.synchronize()
            inference_time = start_event.elapsed_time(end_event)
        else:
            # Keep the same unit as CUDA events (milliseconds)
            inference_time = (time.perf_counter() - start_time) * 1000.0

        tqdm.write(f"Mean reward for {self.name}: {rewards.mean():.4f}")
        tqdm.write(f"Time: {inference_time/1000:.4f}s")

        if use_cuda_timer:
            # Empty cache
            torch.cuda.empty_cache()

        return {
            "actions": actions.cpu(),
            "rewards": rewards.cpu(),
            "inference_time": inference_time,
            "avg_reward": rewards.cpu().mean(),
        }

    @staticmethod
    def _pad_and_cat_actions(actions_list):
        """Zero-pad each tensor on its last dim to a common length and concatenate on dim 0."""
        max_length = max(action.size(-1) for action in actions_list)
        return torch.cat(
            [
                torch.nn.functional.pad(action, (0, max_length - action.size(-1)))
                for action in actions_list
            ],
            0,
        )

    def _inner(self, policy, td, **kwargs):
        """Inner function to be implemented in subclasses.
        This function returns actions and rewards for the given policy
        """
        raise NotImplementedError("Implement in subclass")
class GreedyEval(EvalBase):
    """Evaluates the policy using greedy decoding and single trajectory"""

    name = "greedy"

    def __init__(self, env, **kwargs):
        # Warn about anything other than a lone ``progress`` entry
        check_unused_kwargs(self, kwargs)
        show_progress = kwargs.get("progress", True)
        super().__init__(env, show_progress)

    def _inner(self, policy, td):
        """Run one greedy rollout on ``td`` and return ``(actions, rewards)``."""
        # The policy gets a clone; the reward is computed on the original ``td``
        result = policy(
            td.clone(), decode_type="greedy", num_starts=0, return_actions=True
        )
        actions = result["actions"]
        return actions, self.env.get_reward(td, actions)
class AugmentationEval(EvalBase):
    """Evaluates the policy via N state augmentations
    `force_dihedral_8` forces the use of 8 augmentations (rotations and flips) as in POMO
    https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8

    Args:
        num_augment (int): Number of state augmentations
        force_dihedral_8 (bool): Whether to force the use of 8 augmentations
    """

    name = "augmentation"

    def __init__(self, env, num_augment=8, force_dihedral_8=False, **kwargs):
        check_unused_kwargs(self, kwargs)
        show_progress = kwargs.get("progress", True)
        super().__init__(env, show_progress)
        self.augmentation = StateAugmentation(
            env.name, num_augment=num_augment, use_dihedral_8=force_dihedral_8
        )

    def _inner(self, policy, td, num_augment=None):
        """Greedy-decode every augmented instance and keep the best per group.

        Args:
            policy: Policy to evaluate
            td: Batch state
            num_augment: Group size used when regrouping results; defaults to
                the augmentation's own ``num_augment``

        Returns:
            Tuple ``(actions, rewards)`` selected by the max reward over dim 1
        """
        n_aug = self.augmentation.num_augment if num_augment is None else num_augment

        # Keep an un-augmented copy: rewards are computed against it
        original_td = td.clone()
        augmented_td = self.augmentation(td)
        result = policy(
            augmented_td.clone(),
            decode_type="greedy",
            num_starts=0,
            return_actions=True,
        )
        flat_actions = result["actions"]

        # Reward on the repeated original state, then regroup per instance
        flat_rewards = self.env.get_reward(batchify(original_td, n_aug), flat_actions)
        grouped_rewards = unbatchify(flat_rewards, n_aug)
        grouped_actions = unbatchify(flat_actions, n_aug)

        # Best reward and the matching action
        best_rewards, best_idx = grouped_rewards.max(dim=1)
        best_actions = gather_by_index(grouped_actions, best_idx, dim=1)
        return best_actions, best_rewards

    @property
    def num_augment(self):
        """Number of augmentations, as stored on the augmentation object."""
        return self.augmentation.num_augment
class SamplingEval(EvalBase):
    """Evaluates the policy via N samples from the policy

    Args:
        samples (int): Number of samples to take
        softmax_temp (float): Temperature for softmax sampling. The higher the temperature, the more random the sampling
    """

    name = "sampling"

    def __init__(self, env, samples, softmax_temp=None, **kwargs):
        check_unused_kwargs(self, kwargs)
        show_progress = kwargs.get("progress", True)
        super().__init__(env, show_progress)
        self.samples = samples
        self.softmax_temp = softmax_temp

    def _inner(self, policy, td):
        """Sample ``self.samples`` rollouts per instance and keep the best one.

        Returns:
            Tuple ``(actions, rewards)`` selected by the max reward over dim 1
        """
        n_samples = self.samples
        repeated_td = batchify(td, n_samples)
        result = policy(
            repeated_td.clone(),
            decode_type="sampling",
            num_starts=0,
            return_actions=True,
            softmax_temp=self.softmax_temp,
        )
        flat_actions = result["actions"]

        # Reward on the repeated state, then regroup per instance
        flat_rewards = self.env.get_reward(repeated_td, flat_actions)
        grouped_rewards = unbatchify(flat_rewards, n_samples)
        grouped_actions = unbatchify(flat_actions, n_samples)

        # Best reward and the matching action
        best_rewards, best_idx = grouped_rewards.max(dim=1)
        best_actions = gather_by_index(grouped_actions, best_idx, dim=1)
        return best_actions, best_rewards
class GreedyMultiStartEval(EvalBase):
    """Evaluates the policy via `num_starts` greedy multistarts samples from the policy

    Args:
        num_starts (int): Number of greedy multistarts to use
    """

    name = "greedy_multistart"

    def __init__(self, env, num_starts=None, **kwargs):
        check_unused_kwargs(self, kwargs)
        show_progress = kwargs.get("progress", True)
        super().__init__(env, show_progress)

        assert num_starts is not None, "Must specify num_starts"
        self.num_starts = num_starts

    def _inner(self, policy, td):
        """Decode ``self.num_starts`` greedy rollouts per instance and keep the best.

        Returns:
            Tuple ``(actions, rewards)`` selected by the max reward over dim 1
        """
        n_starts = self.num_starts
        # Keep a copy of the incoming state: rewards are computed against it
        original_td = td.clone()
        result = policy(
            td.clone(),
            decode_type="greedy_multistart",
            num_starts=n_starts,
            return_actions=True,
        )
        flat_actions = result["actions"]

        # Reward on the repeated original state, then regroup per instance
        repeated_td = batchify(original_td, n_starts)
        flat_rewards = self.env.get_reward(repeated_td, flat_actions)
        grouped_rewards = unbatchify(flat_rewards, n_starts)
        grouped_actions = unbatchify(flat_actions, n_starts)

        # Best trajectory per instance
        best_rewards, best_idx = grouped_rewards.max(dim=1)
        best_actions = gather_by_index(grouped_actions, best_idx, dim=1)
        return best_actions, best_rewards
class GreedyMultiStartAugmentEval(EvalBase):
    """Evaluates the policy via `num_starts` samples from the policy
    and `num_augment` augmentations of each sample.
    `force_dihedral_8` forces the use of 8 augmentations (rotations and flips) as in POMO
    https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8

    Args:
        num_starts: Number of greedy multistart samples
        num_augment: Number of augmentations per sample
        force_dihedral_8: If True, force the use of 8 augmentations (rotations and flips) as in POMO
    """

    name = "greedy_multistart_augment"

    def __init__(
        self, env, num_starts=None, num_augment=8, force_dihedral_8=False, **kwargs
    ):
        check_unused_kwargs(self, kwargs)
        show_progress = kwargs.get("progress", True)
        super().__init__(env, show_progress)

        assert num_starts is not None, "Must specify num_starts"
        self.num_starts = num_starts

        # Dihedral-8 is only allowed together with exactly 8 augmentations
        assert (
            num_augment == 8 or not force_dihedral_8
        ), "Cannot force dihedral 8 when num_augment != 8"
        self.augmentation = StateAugmentation(
            env.name, num_augment=num_augment, use_dihedral_8=force_dihedral_8
        )

    def _inner(self, policy, td, num_augment=None):
        """Greedy multistart decoding on augmented instances; keep the best rollout.

        Args:
            policy: Policy to evaluate
            td: Batch state
            num_augment: Augmentation count used when regrouping results;
                defaults to the augmentation's own ``num_augment``

        Returns:
            Tuple ``(actions, rewards)`` selected by the max reward over dim 1
        """
        n_aug = self.augmentation.num_augment if num_augment is None else num_augment
        n_starts = self.num_starts

        # Keep an un-augmented copy: rewards are computed against it
        original_td = td.clone()
        augmented_td = self.augmentation(td)
        result = policy(
            augmented_td.clone(),
            decode_type="greedy_multistart",
            num_starts=n_starts,
            return_actions=True,
        )
        flat_actions = result["actions"]

        # Reward on the repeated original state, then regroup per instance
        repeated_td = batchify(original_td, (n_aug, n_starts))
        flat_rewards = self.env.get_reward(repeated_td, flat_actions)
        group_size = n_starts * n_aug
        grouped_rewards = unbatchify(flat_rewards, group_size)
        grouped_actions = unbatchify(flat_actions, group_size)

        # Best trajectory per instance
        best_rewards, best_idx = grouped_rewards.max(dim=1)
        best_actions = gather_by_index(grouped_actions, best_idx, dim=1)
        return best_actions, best_rewards

    @property
    def num_augment(self):
        """Number of augmentations, as stored on the augmentation object."""
        return self.augmentation.num_augment
def get_automatic_batch_size(eval_fn, start_batch_size=8192, max_batch_size=4096):
    """Automatically reduces the batch size based on the eval function

    The starting batch size is divided by ``num_starts // 10``, ``num_augment``
    and ``samples`` (whichever attributes ``eval_fn`` has), capped at
    ``max_batch_size`` and rounded down to a power of two. The result is
    always at least 1.

    Args:
        eval_fn: The eval function
        start_batch_size: The starting batch size. This should be the theoretical maximum batch size
        max_batch_size: The maximum batch size. This is the practical maximum batch size

    Returns:
        The batch size to use (a power of two, >= 1)
    """
    batch_size = start_batch_size
    effective_ratio = 1

    num_starts = getattr(eval_fn, "num_starts", None)
    if num_starts is not None:
        # ``num_starts // 10`` is 0 for fewer than 10 starts; never divide by 0
        starts_ratio = max(1, num_starts // 10)
        batch_size = batch_size // starts_ratio
        effective_ratio *= starts_ratio

    for attr in ("num_augment", "samples"):
        factor = getattr(eval_fn, attr, None)
        if factor is not None:
            factor = max(1, factor)
            batch_size = batch_size // factor
            effective_ratio *= factor

    # Cap at the practical maximum but never drop below a single instance
    batch_size = max(1, int(min(batch_size, max_batch_size)))

    # Round down to a power of 2 (integer arithmetic, exact for any size)
    batch_size = 1 << (batch_size.bit_length() - 1)

    print(f"Effective batch size: {batch_size} (ratio: {effective_ratio})")
    return batch_size
def evaluate_policy(
    env,
    policy,
    dataset,
    method="greedy",
    batch_size=None,
    max_batch_size=4096,
    start_batch_size=8192,
    auto_batch_size=True,
    save_results=False,
    save_fname="results.npz",
    **kwargs,
):
    """Evaluate ``policy`` on ``dataset`` with one of the predefined methods.

    Args:
        env: Environment; its ``num_loc`` attribute (if any) is the default
            ``num_starts`` for the multistart methods
        policy: Policy to evaluate
        dataset: Dataset wrapped in a ``DataLoader``
        method: Name of the evaluation method
        batch_size: Batch size; must be None when ``auto_batch_size`` is True
        max_batch_size: Passed to ``get_automatic_batch_size``
        start_batch_size: Passed to ``get_automatic_batch_size``
        auto_batch_size: Whether to compute the batch size automatically
        save_results: Whether to save the results with ``np.savez``
        save_fname: File name used when saving
        **kwargs: Override the method's default constructor arguments

    Returns:
        The dictionary returned by the evaluation function
    """
    num_loc = getattr(env, "num_loc", None)

    # method name -> (evaluator class, default constructor kwargs)
    registry = {
        "greedy": (GreedyEval, {}),
        "sampling": (SamplingEval, {"samples": 100, "softmax_temp": 1.0}),
        "greedy_multistart": (GreedyMultiStartEval, {"num_starts": num_loc}),
        "augment_dihedral_8": (
            AugmentationEval,
            {"num_augment": 8, "force_dihedral_8": True},
        ),
        "augment": (AugmentationEval, {"num_augment": 8}),
        "greedy_multistart_augment_dihedral_8": (
            GreedyMultiStartAugmentEval,
            {"num_augment": 8, "force_dihedral_8": True, "num_starts": num_loc},
        ),
        "greedy_multistart_augment": (
            GreedyMultiStartAugmentEval,
            {"num_augment": 8, "num_starts": num_loc},
        ),
    }
    assert method in registry, "Method {} not found".format(method)

    # Caller-supplied kwargs take precedence over the method defaults
    evaluator_cls, defaults = registry[method]
    evaluator_kwargs = {**defaults, **kwargs}
    eval_fn = evaluator_cls(env, **evaluator_kwargs)

    if auto_batch_size:
        assert (
            batch_size is None
        ), "Cannot specify batch_size when auto_batch_size is True"
        batch_size = get_automatic_batch_size(
            eval_fn, max_batch_size=max_batch_size, start_batch_size=start_batch_size
        )
        print("Using automatic batch size: {}".format(batch_size))

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=tensordict_collate_fn,
    )

    results = eval_fn(policy, loader)

    if save_results:
        print("Saving results to {}".format(save_fname))
        np.savez(save_fname, **results)

    return results

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.