Shortcuts

Source code for rl4co.models.zoo.eas.search

import time

from functools import partial
from typing import Any, List, Union

import torch

from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

from rl4co.data.transforms import StateAugmentation
from rl4co.models.nn.utils import get_log_likelihood
from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas.decoder import forward_eas, forward_logit_attn_eas_lay
from rl4co.models.zoo.eas.nn import EASLayerNet
from rl4co.utils.ops import batchify, gather_by_index, get_num_starts, unbatchify
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class EAS(SearchBase):
    """Efficient Active Search (EAS) for neural combinatorial optimization,
    from Hottung et al. (2022).

    Fine-tunes a subset of parameters (such as node embeddings or newly added layers)
    thus avoiding expensive re-encoding of the problem.
    Reference: https://openreview.net/pdf?id=nO5caZwFwYu

    Args:
        env: RL4CO environment to be solved
        policy: policy network
        dataset: dataset to be used for training
        use_eas_embedding: whether to use EAS embedding (EASEmb)
        use_eas_layer: whether to use EAS layer (EASLay)
        eas_emb_cache_keys: keys to cache in the embedding
        eas_lambda: lambda parameter for IL loss
        batch_size: batch size for training
        max_iters: maximum number of iterations
        augment_size: number of augmentations per state
        augment_dihedral: whether to augment with dihedral rotations
        num_parallel_runs: number of parallel runs
        baseline: REINFORCE baseline type (multistart, symmetric, full)
        max_runtime: maximum runtime in seconds
        save_path: path to save solution checkpoints
        optimizer: optimizer to use for training
        optimizer_kwargs: keyword arguments for optimizer
        verbose: whether to print progress for each iteration
    """

    def __init__(
        self,
        env,
        policy,
        dataset: Union[Dataset, str],
        use_eas_embedding: bool = True,
        use_eas_layer: bool = False,
        # NOTE: the mutable defaults below are kept for interface compatibility;
        # this class only reads them and never mutates them.
        eas_emb_cache_keys: List[str] = ["logit_key"],
        eas_lambda: float = 0.013,
        batch_size: int = 2,
        max_iters: int = 200,
        augment_size: int = 8,
        augment_dihedral: bool = True,
        num_parallel_runs: int = 1,
        baseline: str = "multistart",
        max_runtime: int = 86_400,
        save_path: str = None,
        optimizer: Union[str, torch.optim.Optimizer, partial] = "Adam",
        optimizer_kwargs: dict = {"lr": 0.0041, "weight_decay": 1e-6},
        verbose: bool = True,
        **kwargs,
    ):
        # Must run first: the checks below read the arguments through `self.hparams`.
        self.save_hyperparameters(logger=False)

        assert (
            self.hparams.use_eas_embedding or self.hparams.use_eas_layer
        ), "At least one of `use_eas_embedding` or `use_eas_layer` must be True."

        super(EAS, self).__init__(
            env,
            policy=policy,
            dataset=dataset,
            batch_size=batch_size,
            max_iters=max_iters,
            max_runtime=max_runtime,
            save_path=save_path,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            **kwargs,
        )

        assert self.hparams.baseline in [
            "multistart",
            "symmetric",
            "full",
        ], f"Baseline {self.hparams.baseline} not supported."

    def setup(self, stage="fit"):
        """Setup base class and instantiate:
        - augmentation
        - instance solutions and rewards
        - original policy state dict
        """
        log.info(
            f"Setting up Efficient Active Search (EAS) with: \n"
            f"- EAS Embedding: {self.hparams.use_eas_embedding} \n"
            f"- EAS Layer: {self.hparams.use_eas_layer} \n"
        )
        super(EAS, self).setup(stage)

        # Instantiate augmentation
        self.augmentation = StateAugmentation(
            self.env.name,
            num_augment=self.hparams.augment_size,
            use_dihedral_8=self.hparams.augment_dihedral,
        )

        # Store original policy state dict.
        # NOTE(review): `state_dict()` returns references to the live tensors, not
        # copies, so this snapshot tracks any in-place change made to the policy.
        self.original_policy_state = self.policy.state_dict()

        # Get problem size from the width of the action mask of one reset batch
        _batch = next(iter(self.train_dataloader()))
        self.problem_size = self.env.reset(_batch)["action_mask"].shape[-1]

        # Per-batch results, filled in `on_train_batch_end`
        self.instance_solutions = []
        self.instance_rewards = []

    def on_train_batch_start(self, batch: Any, batch_idx: int):
        """Called before training (i.e. search) for a new batch begins.
        We re-load the original policy state dict and configure all parameters not to
        require gradients. We do the rest in the training step.
        """
        self.policy.load_state_dict(self.original_policy_state)

        # Set all policy parameters to not require gradients
        for param in self.policy.parameters():
            param.requires_grad = False

    def training_step(self, batch, batch_idx):
        """Main search loop. We use the training step to effectively adapt to a
        `batch` of instances.

        Returns:
            dict with `max_reward` (best reward per instance found in the last executed
            iteration, or `-inf` if no iteration ran) and `best_solutions` (the
            incumbent solution buffer).
        """
        # Augment state
        batch_size = batch.shape[0]
        td_init = self.env.reset(batch)
        n_aug, n_start, n_runs = (
            self.augmentation.num_augment,
            get_num_starts(td_init),
            self.hparams.num_parallel_runs,
        )
        td_init = self.augmentation(td_init)
        td_init = batchify(td_init, n_runs)
        num_instances = batch_size * n_aug * n_runs  # NOTE: no num_starts!
        # number of different rollouts per instance
        # (+1 for incumbent solution construction)
        group_s = n_start + 1

        # Get encoder and decoder for simplicity
        encoder = self.policy.encoder
        decoder = self.policy.decoder

        # Precompute the cache of the embeddings (i.e. q,k,v and logit_key)
        embeddings, _ = encoder(td_init)
        cached_embeds = decoder._precompute_cache(embeddings, num_starts=group_s)

        # Collect optimizer parameters
        opt_params = []
        if self.hparams.use_eas_layer:
            # EASLay: replace forward of logit attention computation. EASLayer
            eas_layer = EASLayerNet(num_instances, decoder.embedding_dim).to(batch.device)
            decoder.logit_attention.eas_layer = partial(
                eas_layer, decoder.logit_attention
            )
            decoder.logit_attention.forward = partial(
                forward_logit_attn_eas_lay, decoder.logit_attention
            )
            opt_params.extend(eas_layer.parameters())
        if self.hparams.use_eas_embedding:
            # EASEmb: set gradient of emb_key to True
            # for all the keys, wrap the embedding in a nn.Parameter
            for key in self.hparams.eas_emb_cache_keys:
                setattr(
                    cached_embeds, key, torch.nn.Parameter(getattr(cached_embeds, key))
                )
                opt_params.append(getattr(cached_embeds, key))
        # The EAS forward is needed in both modes: it consumes `best_solutions` and
        # `iter_count`, which are passed to the decoder below.
        decoder.forward = partial(forward_eas, decoder)

        # NOTE(review): the return value is discarded here, so whether the optimizer
        # obtained from `self.optimizers()` below actually holds `opt_params` depends
        # on `SearchBase.configure_optimizers` (not visible here) -- confirm.
        self.configure_optimizers(opt_params)

        # Solution and reward buffer
        max_reward = torch.full((batch_size,), -float("inf"), device=batch.device)
        # i.e. incumbent solutions
        # NOTE(review): sized by `batch_size`, while the per-iteration best solutions
        # below have `n_runs * batch_size` rows -- verify for num_parallel_runs > 1.
        best_solutions = torch.zeros(
            batch_size, self.problem_size * 2, device=batch.device, dtype=int
        )

        # Init search
        t_start = time.time()
        for iter_count in range(self.hparams.max_iters):
            # Evaluate policy with sampling multistarts passing the cached embeddings
            best_solutions_expanded = best_solutions.repeat(n_aug, 1).repeat(n_runs, 1)
            log_p, actions, td_out, reward = decoder(
                td_init.clone(),
                cached_embeds=cached_embeds,
                best_solutions=best_solutions_expanded,
                iter_count=iter_count,
                env=self.env,
                decode_type="multistart_sampling",
                num_starts=n_start,
            )

            # Unbatchify to get correct dimensions
            ll = get_log_likelihood(log_p, actions, td_out.get("mask", None))
            ll = unbatchify(ll, (n_runs * batch_size, n_aug, group_s)).squeeze()
            reward = unbatchify(reward, (n_runs * batch_size, n_aug, group_s)).squeeze()
            actions = unbatchify(actions, (n_runs * batch_size, n_aug, group_s)).squeeze()

            # Compute REINFORCE loss with shared baselines
            # compared to original EAS, we also support symmetric and full baselines
            group_reward = reward[..., :-1]  # exclude incumbent solution
            if self.hparams.baseline == "multistart":
                bl_val = group_reward.mean(dim=-1, keepdim=True)
            elif self.hparams.baseline == "symmetric":
                bl_val = group_reward.mean(dim=-2, keepdim=True)
            elif self.hparams.baseline == "full":
                bl_val = group_reward.mean(dim=-1, keepdim=True).mean(
                    dim=-2, keepdim=True
                )
            else:
                raise ValueError(f"Baseline {self.hparams.baseline} not supported.")

            # REINFORCE loss
            advantage = group_reward - bl_val
            loss_rl = -(advantage * ll[..., :-1]).mean()
            # IL loss: log-likelihood of the incumbent rollout (last in each group)
            loss_il = -ll[..., -1].mean()
            # Total loss
            loss = loss_rl + self.hparams.eas_lambda * loss_il

            # Manual backpropagation
            opt = self.optimizers()
            opt.zero_grad()
            self.manual_backward(loss)
            # Apply the update; without this step the gradients are never used.
            opt.step()

            # Save best solutions and rewards
            # Get max reward for each group and instance
            max_reward = reward.max(dim=2)[0].max(dim=1)[0]

            # Reshape and rank rewards
            reward_group = reward.reshape(n_runs * batch_size, -1)
            _, top_indices = torch.topk(reward_group, k=1, dim=1)

            # Obtain best solutions found so far
            solutions = actions.reshape(n_runs * batch_size, n_aug * group_s, -1)
            best_solutions_iter = gather_by_index(solutions, top_indices, dim=1)
            best_solutions[:, : best_solutions_iter.shape[1]] = best_solutions_iter

            self.log_dict(
                {
                    "loss": loss,
                    "max_reward": max_reward.mean(),
                    "step": iter_count,
                    "time": time.time() - t_start,
                },
                on_step=self.log_on_step,
            )
            # Per-iteration progress is only printed when `verbose` is enabled
            if self.hparams.verbose:
                log.info(
                    f"{iter_count}/{self.hparams.max_iters} | "
                    f" Reward: {max_reward.mean().item():.2f} "
                )

            # Stop if max runtime is exceeded
            if time.time() - t_start > self.hparams.max_runtime:
                log.info(f"Max runtime of {self.hparams.max_runtime} seconds exceeded.")
                break

        return {"max_reward": max_reward, "best_solutions": best_solutions}

    def on_train_batch_end(
        self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """We store the best solution and reward found."""
        max_rewards, best_solutions = outputs["max_reward"], outputs["best_solutions"]
        self.instance_solutions.append(best_solutions)
        self.instance_rewards.append(max_rewards)
        log.info(f"Best reward: {max_rewards.mean():.2f}")

    def on_train_epoch_end(self) -> None:
        """Called when the train ends. Merges the per-batch results and optionally
        saves them to `save_path`."""
        save_path = self.hparams.save_path

        # concatenate solutions and rewards
        self.instance_solutions = pad_sequence(
            self.instance_solutions, batch_first=True, padding_value=0
        ).squeeze()
        self.instance_rewards = torch.cat(self.instance_rewards, dim=0).squeeze()

        if save_path is not None:
            log.info(f"Saving solutions and rewards to {save_path}...")
            torch.save(
                {"solutions": self.instance_solutions, "rewards": self.instance_rewards},
                save_path,
            )

        # Stop after a single pass over the data
        # https://github.com/Lightning-AI/lightning/issues/1406
        self.trainer.should_stop = True
class EASEmb(EAS):
    """EAS with embedding adaptation.

    Forces `use_eas_embedding=True` and `use_eas_layer=False`; all other arguments
    are forwarded to `EAS` unchanged.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Warn only when the caller explicitly passed a conflicting value. The
        # `.get` defaults match the values enforced below, so omitting both
        # keywords is silent.
        if not kwargs.get("use_eas_embedding", True) or kwargs.get(
            "use_eas_layer", False
        ):
            log.warning(
                "Setting `use_eas_embedding` to True and `use_eas_layer` to False. Use EAS base class to override."
            )
        kwargs["use_eas_embedding"] = True
        kwargs["use_eas_layer"] = False
        super(EASEmb, self).__init__(*args, **kwargs)
class EASLay(EAS):
    """EAS with layer adaptation.

    Forces `use_eas_embedding=False` and `use_eas_layer=True`; all other arguments
    are forwarded to `EAS` unchanged.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Warn only when the caller explicitly passed a value that conflicts with
        # the settings enforced below.
        if kwargs.get("use_eas_embedding", False) or not kwargs.get(
            "use_eas_layer", True
        ):
            # The message states what this class actually enforces
            # (embedding off, layer on).
            log.warning(
                "Setting `use_eas_embedding` to False and `use_eas_layer` to True. Use EAS base class to override."
            )
        kwargs["use_eas_embedding"] = False
        kwargs["use_eas_layer"] = True
        super(EASLay, self).__init__(*args, **kwargs)

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.