Search Methods

In this notebook, we show how to use the Efficient Active Search (EAS) algorithm to improve on the solutions that a pretrained policy finds for given problem instances.

Installation

Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!

Note: You may need to restart the runtime in Colab after installing.

[ ]:
# !pip install rl4co[graph] # include torch-geometric

## NOTE: to install the latest version from GitHub (may be unstable), install from source instead:
# !pip install git+https://github.com/ai4co/rl4co.git

Imports

[1]:
%load_ext autoreload
%autoreload 2

import torch

from rl4co.envs import TSPEnv, CVRPEnv
from rl4co.models.zoo.am import AttentionModel
from rl4co.utils.trainer import RL4COTrainer
from rl4co.models.nn.utils import get_log_likelihood
from rl4co.models.zoo import EAS, EASLay, EASEmb, ActiveSearch

import logging
[2]:
# Load from checkpoint; alternatively, simply instantiate a new model
checkpoint_path = "last.ckpt" # model trained for only one epoch, just to illustrate the examples

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load checkpoint
# checkpoint = torch.load(checkpoint_path)

model = AttentionModel.load_from_checkpoint(checkpoint_path, load_baseline=False)
policy = model.policy.to(device)
[3]:
# env = CVRPEnv(num_loc=50)
# policy = AttentionModel(env).policy.to(device)

env = TSPEnv(num_loc=50)

td = env.reset(batch_size=3).to(device)

out = policy(td, return_actions=True)
[4]:
env.render(td.cpu(), out["actions"].cpu())
../../_images/_collections_tutorials_4-search-methods_7_0.png

EAS

We perform a few iterations of EASLay for demonstration.

[5]:
logging.basicConfig(level=logging.DEBUG)

env.num_loc = 200

dataset = env.dataset(batch_size=[2])
# eas_model = EASEmb(env, policy, dataset, batch_size=2, max_iters=20, save_path="eas_sols.pt") # alternative
eas_model = EASLay(env, policy, dataset, batch_size=2, max_iters=20, save_path="eas_sols.pt")

eas_model.setup()
INFO:rl4co.models.rl.common.base:No metrics specified, using default
INFO:rl4co.models.zoo.eas.search:Setting up Efficient Active Search (EAS) with:
- EAS Embedding: False
- EAS Layer: True

[6]:
# Plot initial solution
td_dataset = next(iter(eas_model.train_dataloader()))
td_dataset = env.reset(td_dataset).to(device)
out = policy(td_dataset, return_actions=True)

env.render(td_dataset.cpu(), out["actions"].cpu())
INFO:rl4co.models.zoo.common.autoregressive.policy:Instantiated environment not provided; instantiating tsp
../../_images/_collections_tutorials_4-search-methods_10_1.png
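
The `eas_sols.pt` file loaded in the next section is written by the search itself, so the search has to be run first. A minimal sketch of that step, assuming the EAS model can be fitted like any other Lightning module through the `RL4COTrainer` imported above. The helper name `run_search` and the `max_epochs=1` and `gradient_clip_val=None` settings are assumptions for illustration, not taken from this page.

```python
def run_search(search_model, max_epochs=1):
    """Run test-time search by fitting the search model with a trainer.

    `search_model` is expected to be a set-up search model such as `eas_model` above.
    """
    # Imported lazily so that defining this helper does not itself require rl4co
    from rl4co.utils.trainer import RL4COTrainer

    trainer = RL4COTrainer(
        max_epochs=max_epochs,   # assumption: one pass over the search dataset
        gradient_clip_val=None,  # assumption: no trainer-side gradient clipping
    )
    trainer.fit(search_model)
    return trainer


# Usage in this notebook (uncomment to run the search):
# trainer = run_search(eas_model)
```

With `max_iters=20` and `save_path="eas_sols.pt"` set on the model above, the solutions should be available on disk once fitting finishes.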

Load actions

[10]:
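# Load the best solution found for the first instance
actions = torch.load("eas_sols.pt")["solutions"][0].cpu()
# Remove trailing zero padding: slice up to the last nonzero entry, because simply
# counting nonzero entries would shorten a tour that visits node 0 before its end
last_nonzero = torch.nonzero(actions).squeeze(-1)[-1].item()
actions = actions[: last_nonzero + 1]
state = td_dataset.cpu()[0]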

env.render(state, actions)
../../_images/_collections_tutorials_4-search-methods_14_0.png
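
Stripping zero padding from a saved solution needs some care, because index 0 is also a valid node (a TSP node, or the depot in CVRP). Counting the nonzero entries underestimates the length whenever 0 appears inside the route. The sketch below uses plain Python lists and a hypothetical helper, `strip_trailing_zeros`, to show the difference on made-up data.

```python
def strip_trailing_zeros(actions):
    """Return `actions` without the run of zeros at its end (interior zeros are kept)."""
    end = len(actions)
    while end > 0 and actions[end - 1] == 0:
        end -= 1
    return actions[:end]


padded = [3, 0, 2, 1, 0, 0]  # node 0 is visited second; the last two zeros are padding

# Counting nonzero entries gives 3, which cuts off the final node 1
num_nonzero = sum(1 for a in padded if a != 0)
print(padded[:num_nonzero])          # [3, 0, 2]
print(strip_trailing_zeros(padded))  # [3, 0, 2, 1]
```

Neither approach can tell a final visit to node 0 apart from padding, so when the route length is known in advance it is safer to slice to that length.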

Even with only a few iterations, the search method clearly finds better solutions than the initial ones!
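
To put a number on "better", one can compare tour lengths before and after the search. The sketch below is plain Python and independent of RL4CO. `tour_length` is a hypothetical helper, and the coordinates are made-up toy data rather than results from this notebook.

```python
import math


def tour_length(locs, tour):
    """Length of the closed tour that visits `locs` in the order given by `tour`."""
    total = 0.0
    for i, node in enumerate(tour):
        nxt = tour[(i + 1) % len(tour)]  # wrap around to close the tour
        total += math.dist(locs[node], locs[nxt])
    return total


# Toy instance: the four corners of the unit square
locs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
initial = [0, 2, 1, 3]   # uses both diagonals of the square
improved = [0, 1, 2, 3]  # follows the perimeter

print(tour_length(locs, initial), tour_length(locs, improved))
```

The perimeter tour has length 4, while the tour using both diagonals has length 2 + 2√2, so the shorter tour is the better solution under the TSP objective.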