Encoder Customization
In this notebook, we cover a tutorial on customizing the flexible encoders in RL4CO!
Installation
Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!
Note: You may need to restart the runtime in Colab after this.
[1]:
# !pip install rl4co[graph] # include torch-geometric
## NOTE: to install the latest version from GitHub (may be unstable), install from source instead:
# !pip install git+https://github.com/kaist-silab/rl4co.git
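If no GPU runtime is available, the later cells that hardcode accelerator='gpu' in RL4COTrainer would likely fail. A minimal sketch for choosing the accelerator at runtime, assuming only that PyTorch is installed (the variable name accelerator is just illustrative):

```python
import torch

# Use the GPU when CUDA is visible to PyTorch, otherwise fall back to the CPU.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
print(f"Training on: {accelerator}")
```

The resulting string can then be passed as RL4COTrainer(accelerator=accelerator, ...) in place of the hardcoded 'gpu'.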
Imports
[2]:
# Temporary: make the local rl4co checkout importable when running from notebooks/tutorials
import sys
sys.path.append('../..')
%load_ext autoreload
%autoreload 2
from rl4co.envs import CVRPEnv
from rl4co.models.zoo import AttentionModel
from rl4co.utils.trainer import RL4COTrainer
A default minimal training script
Here we use the CVRP environment and the Attention Model (AM) as a minimal training example. By default, the AM is initialized with a Graph Attention Encoder, but we can swap it for any encoder we want.
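Conceptually, swapping works because a policy only builds its default encoder when none is supplied. The toy classes below (ToyDefaultEncoder and ToyPolicy are hypothetical, not RL4CO classes) sketch that pattern in plain PyTorch; RL4CO's actual policy may differ in detail.

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyDefaultEncoder(nn.Module):
    """Hypothetical stand-in for a default encoder: one self-attention layer."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(embedding_dim, num_heads=1, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)  # self-attention over the nodes
        return out


class ToyPolicy(nn.Module):
    """Hypothetical policy: uses the given encoder, else builds the default one."""

    def __init__(self, embedding_dim: int = 16, encoder: Optional[nn.Module] = None):
        super().__init__()
        self.encoder = encoder if encoder is not None else ToyDefaultEncoder(embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)


default_policy = ToyPolicy()
custom_policy = ToyPolicy(encoder=nn.Linear(16, 16))
print(default_policy.encoder._get_name())  # ToyDefaultEncoder
print(custom_policy.encoder._get_name())   # Linear
```

The same _get_name() call is what the cell below uses to report which encoder the AM ended up with.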
[3]:
# Init env, model, trainer
env = CVRPEnv(num_loc=20)
model = AttentionModel(
    env,
    baseline='rollout',
    train_data_size=100_000,  # really small size for demo
    val_data_size=10_000,
)
trainer = RL4COTrainer(
    max_epochs=3,  # few epochs for demo
    accelerator='gpu',
    logger=False,
)
# By default the AM uses the Graph Attention Encoder
print(f'Encoder: {model.policy.encoder._get_name()}')
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Encoder: GraphAttentionEncoder
[4]:
# Train the model
trainer.fit(model)
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
--------------------------------------------------
0 | env | CVRPEnv | 0
1 | policy | AttentionModelPolicy | 694 K
2 | baseline | WarmupBaseline | 694 K
--------------------------------------------------
1.4 M Trainable params
0 Non-trainable params
1.4 M Total params
5.553 Total estimated model params size (MB)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=3` reached.
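The model summary above can be sanity-checked by hand. The baseline row reports the same 694 K parameters as the policy, consistent with the rollout baseline keeping its own copy of the policy, so the total is roughly 2 × 694 K ≈ 1.39 M. The size estimate then matches 4 bytes (32 bits) per parameter, even though training runs with 16-bit AMP. A minimal sketch of that arithmetic, where estimated_size_mb is a hypothetical helper rather than a Lightning function:

```python
def estimated_size_mb(num_params: int, bits_per_param: int = 32) -> float:
    """Estimated parameter memory in MB (1 MB = 1e6 bytes)."""
    return num_params * (bits_per_param / 8) * 1e-6


# Policy (~694 K params) plus a baseline copy of the same size.
total_params = 2 * 694_000
print(f"{estimated_size_mb(total_params):.2f} MB")  # 5.55 MB
```

The small gap to the reported 5.553 MB comes from 694 K being a rounded count; the same calculation applied to the GCN run below (2 × 148 K ≈ 297 K) lands near its reported 1.191 MB.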
Change the Encoder
RL4CO provides two graph neural network encoders: a Graph Convolutional Network (GCN) encoder and a Message Passing Neural Network (MPNN) encoder. In this section, we show how to replace the default encoder with either of them.
Note: while we provide these examples, you can also implement your own encoder and use it in RL4CO!
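As background, a GCN layer mixes each node's features with its neighbours' using a degree-normalized adjacency matrix, following the standard propagation rule H' = ReLU(D^-1/2 (A + I) D^-1/2 H W). The dense sketch below only illustrates that rule in plain PyTorch: DenseGCNLayer is hypothetical, it is not RL4CO's GCNEncoder (which relies on the torch-geometric dependencies installed via rl4co[graph]), and the fully connected graph is an assumption made for simplicity, not how RL4CO necessarily builds its graph.

```python
import torch
import torch.nn as nn


class DenseGCNLayer(nn.Module):
    """Hypothetical GCN layer on a dense adjacency: H' = ReLU(D^-1/2 (A + I) D^-1/2 H W)."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)  # the weight matrix W

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # h: (batch, num_nodes, in_dim); adj: (batch, num_nodes, num_nodes)
        num_nodes = adj.size(-1)
        a_hat = adj + torch.eye(num_nodes, device=adj.device, dtype=adj.dtype)  # add self-loops
        deg_inv_sqrt = a_hat.sum(-1).pow(-0.5)  # (batch, num_nodes)
        # Symmetric normalization: entry (i, j) becomes a_ij / sqrt(d_i * d_j)
        norm_adj = deg_inv_sqrt.unsqueeze(-1) * a_hat * deg_inv_sqrt.unsqueeze(-2)
        return torch.relu(norm_adj @ self.linear(h))


# Fully connected graph over 20 nodes (no self-edges in adj; the layer adds them).
batch, n, dim = 4, 20, 128
adj = torch.ones(batch, n, n) - torch.eye(n)
h = torch.randn(batch, n, dim)
out = DenseGCNLayer(dim, dim)(h, adj)
print(out.shape)  # torch.Size([4, 20, 128])
```

On a fully connected graph every node receives the same normalized average, which is one reason sparser graphs (for example, nearest-neighbour edges) are common in practice.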
[5]:
# Before we init, we need to install the graph neural network dependencies
# !pip install rl4co[graph]
[6]:
# Init the model with a different encoder
from rl4co.models.nn.graph.gcn import GCNEncoder
from rl4co.models.nn.graph.mpnn import MessagePassingEncoder

gcn_encoder = GCNEncoder(
    env_name='cvrp',
    embedding_dim=128,
    num_nodes=20,
    num_layers=3,
)
mpnn_encoder = MessagePassingEncoder(
    env_name='cvrp',
    embedding_dim=128,
    num_nodes=20,
    num_layers=3,
)
model = AttentionModel(
    env,
    baseline='rollout',
    train_data_size=100_000,  # really small size for demo
    val_data_size=10_000,
    policy_kwargs={
        'encoder': gcn_encoder,  # gcn_encoder or mpnn_encoder
    },
)
trainer = RL4COTrainer(
    max_epochs=3,  # few epochs for demo
    accelerator='gpu',
    logger=False,
)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:196: UserWarning: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
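The cell above also builds mpnn_encoder, a MessagePassingEncoder that can be passed in place of gcn_encoder. As rough intuition for what message passing does, here is a hypothetical single layer in plain PyTorch (SimpleMessagePassing is illustrative and not RL4CO's implementation): each edge produces a message from its endpoint features, messages are summed at the target node, and each node state is updated from its previous value and that sum.

```python
import torch
import torch.nn as nn


class SimpleMessagePassing(nn.Module):
    """Hypothetical MPNN layer: message = MLP([h_src, h_dst]), sum-aggregated per node."""

    def __init__(self, dim: int):
        super().__init__()
        self.message_mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU())
        self.update_mlp = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU())

    def forward(self, h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # h: (num_nodes, dim); edge_index: (2, num_edges) with rows [source, target]
        src, dst = edge_index
        messages = self.message_mlp(torch.cat([h[src], h[dst]], dim=-1))
        # Sum incoming messages at each target node
        aggregated = torch.zeros_like(h).index_add_(0, dst, messages)
        # Update each node from its previous state and its aggregated messages
        return self.update_mlp(torch.cat([h, aggregated], dim=-1))


# A directed 3-cycle: 0 -> 1, 1 -> 2, 2 -> 0
h = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
out = SimpleMessagePassing(8)(h, edge_index)
print(out.shape)  # torch.Size([3, 8])
```

Unlike the fixed degree normalization of a GCN, the message function here is learned per edge, which is the main design difference between the two encoder families.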
[7]:
# Train the model
trainer.fit(model)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:615: UserWarning: Checkpoint directory /home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
--------------------------------------------------
0 | env | CVRPEnv | 0
1 | policy | AttentionModelPolicy | 148 K
2 | baseline | WarmupBaseline | 148 K
--------------------------------------------------
297 K Trainable params
0 Non-trainable params
297 K Total params
1.191 Total estimated model params size (MB)
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/home/botu/miniconda3/envs/rl4co/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=3` reached.
Or create your own encoder
If you want to create a new encoder, you can follow the base class below, which has the following components:
1. RL4CO provides the env_init_embedding method for each environment. You can use it to get the initial embedding of the environment.
2. h and init_h are the hidden features returned by forward; both have the shape ([batch_size], num_node, hidden_size).
3. In RL4CO, we put the graph neural network encoders in the rl4co/models/nn/graph folder. You may want to put your customized encoder in the same folder. Feel free to send a PR to add your encoder to RL4CO!
[8]:
# Import necessary packages
from typing import Optional, Tuple

import torch.nn as nn
from torch import Tensor
from tensordict import TensorDict

from rl4co.models.nn.env_embeddings import env_init_embedding


class BaseEncoder(nn.Module):
    def __init__(
        self,
        env_name: str,
        embedding_dim: int,
        init_embedding: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.env_name = env_name
        # Init embedding for each environment (RL4CO default unless one is given)
        self.init_embedding = (
            env_init_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if init_embedding is None
            else init_embedding
        )

    def forward(
        self, td: TensorDict, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            td: Input TensorDict containing the environment state
            mask: Mask to apply to the attention
        Returns:
            h: Latent representation of the input
            init_h: Initial embedding of the input
        """
        init_h = self.init_embedding(td)
        # Placeholder: replace with your own layers that map init_h to h
        h = None
        return h, init_h
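As an illustration of the components listed above, here is a minimal custom encoder that returns h and init_h with shape (batch_size, num_nodes, embedding_dim). To stay self-contained, it does not subclass BaseEncoder or call env_init_embedding; instead it uses a hypothetical LocsInitEmbedding that assumes the input has a "locs" entry of 2-D coordinates, and a plain dict stands in for the TensorDict. MLPEncoder is likewise hypothetical, and whether it plugs directly into AttentionModel via policy_kwargs depends on what the decoder expects, so treat it as a shape-level sketch.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor


class LocsInitEmbedding(nn.Module):
    """Hypothetical init embedding: projects 2-D node coordinates to embedding_dim."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.proj = nn.Linear(2, embedding_dim)

    def forward(self, td) -> Tensor:
        # td: TensorDict or dict with a "locs" entry of shape (batch, num_nodes, 2)
        return self.proj(td["locs"])


class MLPEncoder(nn.Module):
    """Hypothetical custom encoder following the (h, init_h) return contract."""

    def __init__(
        self,
        embedding_dim: int = 128,
        num_layers: int = 2,
        init_embedding: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.init_embedding = (
            LocsInitEmbedding(embedding_dim) if init_embedding is None else init_embedding
        )
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.ReLU())
             for _ in range(num_layers)]
        )

    def forward(self, td, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # mask is accepted for interface compatibility but unused by this sketch
        init_h = self.init_embedding(td)
        h = init_h
        for layer in self.layers:
            h = h + layer(h)  # residual MLP update, applied to each node independently
        return h, init_h


td = {"locs": torch.rand(4, 20, 2)}  # batch of 4 instances with 20 nodes each
h, init_h = MLPEncoder()(td)
print(h.shape, init_h.shape)  # torch.Size([4, 20, 128]) torch.Size([4, 20, 128])
```

Because the layers never exchange information between nodes, this encoder ignores graph structure entirely; swapping the MLP blocks for attention or GNN layers is where the modelling choice actually happens.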