Training: Checkpoints, Logging, and Callbacks¶
In this notebook we walk through a quickstart training run on the Split Delivery Vehicle Routing Problem (SDVRP), with some additional comments along the way. The SDVRP is a variant of the VRP in which a vehicle may deliver only part of a customer’s demand and return later to deliver the rest.
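To make the "split delivery" idea concrete, here is a small, stdlib-only sketch. It is not how rl4co models the problem, and it is not a solver. The function name split_delivery_trips and the fixed visiting order (customers served in index order) are assumptions made purely for illustration. The sketch shows how a customer whose demand exceeds the vehicle's remaining load is served partly on one trip and partly on a later one.

```python
def split_delivery_trips(demands, capacity):
    """Hypothetical illustration of split deliveries (not an SDVRP solver).

    Customers 1..n are served in index order (node 0 is the depot). Whenever the
    vehicle is empty, it returns to the depot, refills to `capacity`, and
    continues, so one customer's demand may be split across several trips.
    Returns a list of trips, each a list of (customer, amount_delivered) pairs.
    """
    if capacity <= 0:
        raise ValueError("capacity must be positive")
    trips, current, load = [], [], capacity
    for customer, demand in enumerate(demands, start=1):
        remaining = demand
        while remaining > 0:
            if load == 0:  # empty vehicle: close this trip and refill at the depot
                trips.append(current)
                current, load = [], capacity
            delivered = min(remaining, load)
            current.append((customer, delivered))
            remaining -= delivered
            load -= delivered
    if current:
        trips.append(current)
    return trips


trips = split_delivery_trips([4, 7, 3], capacity=6)
# Customer 2 is split across trips 1 and 2, customer 3 across trips 2 and 3
print(trips)  # [[(1, 4), (2, 2)], [(2, 5), (3, 1)], [(3, 2)]]
```

In a real SDVRP the visiting order is exactly what the model learns. Only the capacity bookkeeping here reflects the problem definition.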
Installation¶
Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!
Note: you may need to restart the runtime in Colab after installing.
[1]:
# !pip install rl4co
## NOTE: to install the latest version from GitHub (may be unstable), install from source instead:
# !pip install git+https://github.com/kaist-silab/rl4co.git
Imports¶
[2]:
%load_ext autoreload
%autoreload 2
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from rl4co.envs import SDVRPEnv
from rl4co.models.zoo import AttentionModel
from rl4co.utils.trainer import RL4COTrainer
2023-07-23 00:12:34.022428: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-07-23 00:12:34.042563: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-23 00:12:34.372727: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/torchrl/__init__.py:26: UserWarning: failed to set start method to spawn, and current start method for mp is fork.
warn(
Main Setup¶
Environment, Model and LitModule¶
[3]:
# RL4CO env based on TorchRL
env = SDVRPEnv(num_loc=20)
# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(env,
baseline='rollout',
train_data_size=100_000, # really small size for demo
val_data_size=10_000)
trainer = RL4COTrainer(max_epochs=3)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
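The comment above says the model is trained with REINFORCE and a greedy rollout baseline. The framework-free sketch below uses plain floats instead of tensors and a hypothetical function name. It shows the general form of the per-batch loss:

- The advantage is the sampled tour's reward minus the baseline's reward.
- For the rollout baseline, the baseline reward is the reward of a greedy rollout by a frozen copy of the policy, which is only updated from time to time.
- The advantage weights the log-probability of the sampled tour.

```python
import math


def reinforce_loss(rewards, baseline_rewards, log_probs):
    """Sketch of the REINFORCE-with-baseline loss (hypothetical helper).

    loss = -mean((reward - baseline_reward) * log_prob)

    Minimizing it pushes up the log-probability of tours that beat the
    baseline (positive advantage) and pushes down that of worse tours.
    """
    if not (len(rewards) == len(baseline_rewards) == len(log_probs)):
        raise ValueError("all inputs must have the same length")
    advantages = [r - b for r, b in zip(rewards, baseline_rewards)]
    return -sum(a * lp for a, lp in zip(advantages, log_probs)) / len(rewards)


# Rewards are negative tour lengths: the first tour is worse than its
# baseline, the second one is better.
loss = reinforce_loss(
    rewards=[-10.0, -8.0],
    baseline_rewards=[-9.0, -9.0],
    log_probs=[math.log(0.5), math.log(0.25)],
)
print(f"{loss:.4f}")  # 0.3466
```

In an actual implementation the log-probabilities are tensors produced by the policy. The baseline reward is treated as a constant, so gradients flow only through the log-probabilities.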
Test greedy rollout with untrained model and plot¶
[4]:
# Greedy rollouts over untrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
model = model.to(device)
out = model(td_init, phase="test", decode_type="greedy", return_actions=True)
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
env.render(td, actions)
Tour lengths: ['14.51', '16.32', '10.78']
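The print statement above reports -r as the tour length, so the reward is the negative route length. The sketch below computes that length under three assumptions:

- Node 0 is the depot.
- The action sequence may revisit the depot, as SDVRP routes do between trips.
- The route closes at the depot.

The helper name and the coordinate format are illustrative assumptions, not rl4co's implementation.

```python
import math


def tour_length(locs, actions):
    """Euclidean length of a route that starts at the depot (node 0), visits
    the nodes in `actions` in order (0 = back to the depot), and ends at the depot."""
    route = [0, *actions, 0]
    return sum(math.dist(locs[a], locs[b]) for a, b in zip(route, route[1:]))


locs = [(0.0, 0.0), (3.0, 0.0), (3.0, 4.0)]  # depot, customer 1, customer 2
length = tour_length(locs, [1, 2, 0])        # 3 + 4 + 5 + 0 (depot -> depot)
reward = -length
print(length, reward)  # 12.0 -12.0
```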
Training¶
Callbacks¶
Here we set up a checkpoint callback to save the best model, and a RichModelSummary callback that prints a summary of the model’s modules and parameter counts when training starts. Other callbacks are listed in the PyTorch Lightning documentation.
[5]:
# Checkpointing callback: save models when validation reward improves
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",          # save to checkpoints/
    filename="epoch_{epoch:03d}",   # save as epoch_XXX.ckpt
    auto_insert_metric_name=False,  # otherwise Lightning writes epoch_epoch=XXX.ckpt
    save_top_k=1,                   # save only the best model
    save_last=True,                 # also keep the last model as last.ckpt
    monitor="val/reward",           # monitor validation reward
    mode="max",                     # maximize validation reward
)
# Print model summary
rich_model_summary = RichModelSummary(max_depth=3)
# Callbacks list
callbacks = [checkpoint_callback, rich_model_summary]
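To illustrate what save_top_k=1 with mode="max" means, here is a simplified bookkeeping class. It is hypothetical and is not Lightning's ModelCheckpoint code. After each validation, a checkpoint is kept only if its monitored value beats the worst checkpoint currently kept. Separately, save_last=True keeps last.ckpt up to date regardless of the score. The validation rewards in the example are made up.

```python
from dataclasses import dataclass, field


@dataclass
class TopKTracker:
    """Simplified sketch of top-k checkpoint bookkeeping (hypothetical; k >= 1)."""

    k: int = 1
    mode: str = "max"  # "max": a higher monitored value is better
    kept: dict = field(default_factory=dict)  # checkpoint name -> monitored value

    def _is_better(self, new, old):
        return new > old if self.mode == "max" else new < old

    def update(self, epoch, value):
        """Return True if the checkpoint for `epoch` would be kept."""
        name = "epoch_{epoch:03d}.ckpt".format(epoch=epoch)
        if len(self.kept) < self.k:
            self.kept[name] = value
            return True
        # The worst kept checkpoint is the one a better model replaces
        pick_worst = min if self.mode == "max" else max
        worst = pick_worst(self.kept, key=self.kept.get)
        if self._is_better(value, self.kept[worst]):
            del self.kept[worst]
            self.kept[name] = value
            return True
        return False


tracker = TopKTracker(k=1, mode="max")
for epoch, val_reward in enumerate([-9.1, -7.5, -7.8]):  # made-up val/reward values
    tracker.update(epoch, val_reward)
print(tracker.kept)  # {'epoch_001.ckpt': -7.5}
```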
Logging¶
Here we use Weights & Biases (W&B); other loggers are listed in the PyTorch Lightning documentation. To associate your experiments with your account, log into W&B by uncommenting the lines below. Skip this step if you don’t want to use W&B.
[6]:
# import wandb
# wandb.login()
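Instead of editing cells, you can also let an environment variable decide whether to create the W&B logger. W&B itself reads WANDB_MODE, with values such as "offline" or "disabled". The helper below is a hypothetical sketch that mirrors that convention, so the notebook could fall back to logger = None when W&B is disabled.

```python
import os


def wandb_enabled(environ=None):
    """Hypothetical helper: False if WANDB_MODE is set to "disabled".

    "offline" still counts as enabled: W&B then logs locally and the run
    can be synced later.
    """
    environ = os.environ if environ is None else environ
    return environ.get("WANDB_MODE", "online").strip().lower() != "disabled"


print(wandb_enabled({"WANDB_MODE": "disabled"}))  # False
print(wandb_enabled({"WANDB_MODE": "offline"}))   # True
```

For example: `logger = WandbLogger(project="rl4co", name="sdvrp-am") if wandb_enabled() else None`.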
[7]:
## Comment out the following two lines if you don't want logging
from lightning.pytorch.loggers import WandbLogger
logger = WandbLogger(project="rl4co", name="sdvrp-am")
## ...and uncomment the line below instead
# logger = None
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: silab-kaist. Use `wandb login --relogin` to force relogin
./wandb/run-20230723_001237-9t1v3c73
Trainer¶
The RL4CO trainer is a wrapper around PyTorch Lightning’s Trainer class that adds some functionality and more efficient defaults. The Trainer handles logging, checkpointing, and more for you.
[8]:
from rl4co.utils.trainer import RL4COTrainer
trainer = RL4COTrainer(
max_epochs=2,
accelerator="gpu",
logger=logger,
callbacks=callbacks,
)
Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Fit the model¶
[9]:
trainer.fit(model)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:615: UserWarning: Checkpoint directory /home/botu/Dev/rl4co/notebooks/examples/checkpoints exists and is not empty.
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃ ┃ Name ┃ Type ┃ Params ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ env │ SDVRPEnv │ 0 │
│ 1 │ policy │ AttentionModelPolicy │ 694 K │
│ 2 │ policy.encoder │ GraphAttentionEncoder │ 595 K │
│ 3 │ policy.encoder.init_embedding │ VRPInitEmbedding │ 896 │
│ 4 │ policy.encoder.net │ GraphAttentionNetwork │ 594 K │
│ 5 │ policy.decoder │ AutoregressiveDecoder │ 98.8 K │
│ 6 │ policy.decoder.context_embedding │ VRPContext │ 16.5 K │
│ 7 │ policy.decoder.dynamic_embedding │ SDVRPDynamicEmbedding │ 384 │
│ 8 │ policy.decoder.project_node_embeddings │ Linear │ 49.2 K │
│ 9 │ policy.decoder.project_fixed_context │ Linear │ 16.4 K │
│ 10 │ policy.decoder.logit_attention │ LogitAttention │ 16.4 K │
│ 11 │ baseline │ WarmupBaseline │ 694 K │
│ 12 │ baseline.baseline │ RolloutBaseline │ 694 K │
│ 13 │ baseline.baseline.model │ AttentionModelPolicy │ 694 K │
│ 14 │ baseline.warmup_baseline │ ExponentialBaseline │ 0 │
└────┴────────────────────────────────────────┴───────────────────────┴────────┘
Trainable params: 1.4 M
Non-trainable params: 0
Total params: 1.4 M
Total estimated model params size (MB): 5
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=2` reached.
Testing¶
Plotting¶
Here we plot the solutions (greedy rollouts) of the trained model on the same initial instances as before.
[10]:
# Greedy rollouts over trained model (same states as previous plot)
model = model.to(device)
out = model(td_init, phase="test", decode_type="greedy", return_actions=True)
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
env.render(td, actions)
Tour lengths: ['6.73', '6.81', '8.23']
Test function¶
By default, the test dataset is generated or loaded by the environment. To load a dataset from a file, set test_file when configuring the environment:
env = SDVRPEnv(
...
test_file="path/to/test/file"
)
In this case, we test directly on the generated test dataset.
[11]:
trainer.test(model)
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test/reward │ -7.224186897277832 │
└───────────────────────────┴───────────────────────────┘
[11]:
[{'test/reward': -7.224186897277832}]
Test generalization to new dataset¶
Here we generate a new dataset with 50 nodes and test the trained model on it.
[12]:
# Test generalization to 50 nodes (not going to be great due to few epochs, but hey)
env = SDVRPEnv(num_loc=50)
# Generate 50 new instances and wrap them in a dataloader
new_dataset = env.dataset(50)
dataloader = model._dataloader(new_dataset, batch_size=100)
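The dataloader here only batches the generated instances. Assuming env.dataset(50) yields 50 instances, batch_size=100 produces a single batch of 50, and the next cell takes the first three of them. The stdlib sketch below mimics this with a hypothetical instance generator: uniform coordinates in the unit square and integer demands in 1..9. rl4co's actual generator and data layout may differ.

```python
import random


def generate_instances(n, num_loc, seed=0, max_demand=9):
    """Hypothetical generator: depot and customer coordinates uniform in
    [0, 1]^2, integer customer demands in [1, max_demand]."""
    rng = random.Random(seed)
    return [
        {
            "depot": (rng.random(), rng.random()),
            "locs": [(rng.random(), rng.random()) for _ in range(num_loc)],
            "demand": [rng.randint(1, max_demand) for _ in range(num_loc)],
        }
        for _ in range(n)
    ]


def batches(items, batch_size):
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


data = generate_instances(50, num_loc=50)
first_batch = next(iter(batches(data, batch_size=100)))  # only 50 items exist
init_states = first_batch[:3]
print(len(first_batch), len(init_states))  # 50 3
```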
Plotting generalization¶
[13]:
# Greedy rollouts over trained model (three new instances with 50 nodes)
model = model.to(device)
init_states = next(iter(dataloader))[:3]
td_init_generalization = env.reset(init_states).to(device)
out = model(td_init_generalization, phase="test", decode_type="greedy", return_actions=True)
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init_generalization, out['actions'].cpu()):
env.render(td, actions)
Tour lengths: ['11.42', '15.26', '15.04']
Loading model¶
Thanks to PyTorch Lightning, we can easily save a model to a checkpoint and load it back! Checkpointing is configured in the Trainer through the ModelCheckpoint callback. For example, we can load the last model from the last.ckpt file in the folder we passed as dirpath to the callback (checkpoints/).
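When a checkpoint file with the same name already exists, Lightning's ModelCheckpoint appends a "-v{version}" suffix (e.g. last-v1.ckpt) instead of overwriting it. So last.ckpt is not necessarily the newest file when the same checkpoints/ folder is reused across runs; the fit log above warns that the folder was not empty. A small, hypothetical helper that picks the "last" checkpoint with the highest version suffix could look like this. It assumes that a higher suffix means a newer run.

```python
import re
import tempfile
from pathlib import Path

_LAST_CKPT = re.compile(r"^last(?:-v(\d+))?\.ckpt$")


def latest_last_checkpoint(dirpath="checkpoints"):
    """Return the path of last.ckpt / last-vN.ckpt with the highest N, or None.

    Hypothetical helper; assumes a higher version suffix means a newer run.
    """
    best, best_version = None, -1
    for path in Path(dirpath).glob("last*.ckpt"):
        match = _LAST_CKPT.match(path.name)
        if match is None:
            continue  # skip unrelated names such as "last_backup.ckpt"
        version = int(match.group(1) or 0)  # plain last.ckpt counts as version 0
        if version > best_version:
            best, best_version = path, version
    return best


# Demo on a temporary folder with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    for name in ["last.ckpt", "last-v1.ckpt", "last-v2.ckpt", "last-v10.ckpt", "epoch_001.ckpt"]:
        (Path(tmp) / name).touch()
    newest = latest_last_checkpoint(tmp).name
print(newest)  # last-v10.ckpt
```

The result could then be passed to AttentionModel.load_from_checkpoint in place of the hard-coded "checkpoints/last.ckpt".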
Checkpointing¶
[4]:
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils import RL4COTrainer
# Environment, Model, and Lightning Module
env = TSPEnv(num_loc=20)
model = AttentionModel(env,
baseline="rollout",
train_data_size=100_000,
test_data_size=10_000,
optimizer_kwargs={'lr': 1e-4}
)
# Trainer
trainer = RL4COTrainer(max_epochs=3)
# Fit the model
trainer.fit(model)
# Test the model
trainer.test(model)
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
--------------------------------------------------
0 | env | TSPEnv | 0
1 | policy | AttentionModelPolicy | 710 K
2 | baseline | WarmupBaseline | 710 K
--------------------------------------------------
1.4 M Trainable params
0 Non-trainable params
1.4 M Total params
5.681 Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=3` reached.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test/reward │ -4.025482177734375 │
└───────────────────────────┴───────────────────────────┘
[4]:
[{'test/reward': -4.025482177734375}]
[14]:
%load_ext autoreload
%autoreload 2
from rl4co.models.zoo import AttentionModel
# Note that by default, Lightning saves checkpoints from newer runs in the same folder
# with a "-v{version}" suffix (e.g. last-v1.ckpt), so pick the checkpoint path explicitly
new_model_checkpoint = AttentionModel.load_from_checkpoint("checkpoints/last.ckpt", strict=False)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:164: UserWarning: Found keys that are not in the model state dict but in the checkpoint: ['baseline.baseline.model.encoder.init_embedding.init_embed.weight', 'baseline.baseline.model.encoder.init_embedding.init_embed.bias', 'baseline.baseline.model.encoder.init_embedding.init_embed_depot.weight', 'baseline.baseline.model.encoder.init_embedding.init_embed_depot.bias', 'baseline.baseline.model.encoder.net.layers.0.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.0.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.0.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.0.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.0.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.0.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.0.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.0.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.0.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.0.3.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.1.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.1.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.1.0.module.out_proj.weight', 
'baseline.baseline.model.encoder.net.layers.1.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.1.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.1.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.1.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.1.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.1.2.module.2.bias', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.1.3.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.2.0.module.Wqkv.weight', 'baseline.baseline.model.encoder.net.layers.2.0.module.Wqkv.bias', 'baseline.baseline.model.encoder.net.layers.2.0.module.out_proj.weight', 'baseline.baseline.model.encoder.net.layers.2.0.module.out_proj.bias', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.2.1.normalizer.num_batches_tracked', 'baseline.baseline.model.encoder.net.layers.2.2.module.0.weight', 'baseline.baseline.model.encoder.net.layers.2.2.module.0.bias', 'baseline.baseline.model.encoder.net.layers.2.2.module.2.weight', 'baseline.baseline.model.encoder.net.layers.2.2.module.2.bias', 
'baseline.baseline.model.encoder.net.layers.2.3.normalizer.weight', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.bias', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.running_mean', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.running_var', 'baseline.baseline.model.encoder.net.layers.2.3.normalizer.num_batches_tracked', 'baseline.baseline.model.decoder.context_embedding.project_context.weight', 'baseline.baseline.model.decoder.dynamic_embedding.projection.weight', 'baseline.baseline.model.decoder.project_node_embeddings.weight', 'baseline.baseline.model.decoder.project_fixed_context.weight', 'baseline.baseline.model.decoder.logit_attention.project_out.weight']
rank_zero_warn(
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
The autoreload extension is already loaded. To reload it, use:
%reload_ext autoreload
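The long warning above comes from loading with strict=False. Keys that exist in the checkpoint but not in the freshly built model are reported and skipped instead of raising an error; with PyTorch's load_state_dict, strict=True raises on missing or unexpected keys. In this run every reported key starts with baseline.baseline.model., i.e. it belongs to the rollout baseline's copy of the policy. Since no missing-key warning appears, the policy weights themselves were loaded. The sketch below shows the set arithmetic behind the missing/unexpected report, using toy key lists and a hypothetical function name.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as reported by non-strict loading.

    missing:    expected by the model but absent from the checkpoint
    unexpected: present in the checkpoint but unknown to the model
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return sorted(model_keys - checkpoint_keys), sorted(checkpoint_keys - model_keys)


# Toy key names modeled on the warning above
model_keys = ["policy.encoder.net.layers.0.0.module.Wqkv.weight",
              "policy.decoder.project_node_embeddings.weight"]
checkpoint_keys = model_keys + ["baseline.baseline.model.decoder.project_node_embeddings.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print(missing)     # []
print(unexpected)  # ['baseline.baseline.model.decoder.project_node_embeddings.weight']
```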
Now we can load both the model and environment from the checkpoint!
[15]:
# Greedy rollouts over trained model (same states as previous plot, with 20 nodes)
model = new_model_checkpoint.to(device)
env = new_model_checkpoint.env.to(device)
out = model(td_init, phase="test", decode_type="greedy", return_actions=True)
# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
env.render(td, actions)
Tour lengths: ['8.22', '9.06', '8.23']
Additional resources¶
Documentation | Getting Started | Usage | Contributing | Paper | Citation
Have feedback about this notebook? Feel free to contribute by either opening an issue or a pull request! ;)