
Encoder Customization

In this notebook, we walk through a tutorial on RL4CO's flexible encoders and how to customize them!


Installation

Uncomment the following line to install the package from PyPI. Remember to choose a GPU runtime for faster training!

Note: You may need to restart the runtime in Colab after this.

[1]:
# !pip install rl4co[graph] # include torch-geometric

## NOTE: to install the latest version from GitHub (may be unstable), install from source instead:
# !pip install git+https://github.com/ai4co/rl4co.git

Imports

[1]:
from rl4co.envs import CVRPEnv

from rl4co.models.zoo import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

A default minimal training script

Here we use the CVRP environment and the Attention Model (AM) as a minimal example of a training script. By default, the AM is initialized with a Graph Attention Encoder, but we can replace it with any encoder we want.
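The Graph Attention Encoder is built around attention over the nodes of the instance. As a rough intuition only, here is a plain-Python sketch of single-head scaled dot-product self-attention over node embeddings. It is not RL4CO's implementation: it omits the learned query/key/value projections, multiple heads, normalization, and feed-forward sublayers that a real encoder layer has, and the function names are illustrative.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max score for numerical stability before exponentiating
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(h: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product self-attention.

    Queries, keys, and values are the node embeddings themselves
    (no learned projections), so this only shows the mixing pattern.
    """
    d = len(h[0])
    out: List[Vector] = []
    for q in h:
        # Compatibility of this node with every node, scaled by sqrt(d)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in h]
        weights = softmax(scores)
        # New embedding = attention-weighted sum of all node embeddings
        out.append([sum(w * v[j] for w, v in zip(weights, h)) for j in range(d)])
    return out


nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
updated = self_attention(nodes)
print(len(updated), len(updated[0]))  # 3 2
```

Every node attends to every other node, which is why attention-based encoders treat the instance as a fully connected graph.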

[3]:
# Init env, model, trainer
env = CVRPEnv(num_loc=20)

model = AttentionModel(
    env,
    baseline='rollout',
    train_data_size=100_000, # really small size for demo
    val_data_size=10_000
)

trainer = RL4COTrainer(
    max_epochs=3, # few epochs for demo
    accelerator='gpu',
    devices=1,
    logger=False,
)

# By default the AM uses the Graph Attention Encoder
print(f'Encoder: {model.policy.encoder._get_name()}')
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Encoder: GraphAttentionEncoder
[4]:
# Train the model
trainer.fit(model)
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /datasets/home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type                 | Params
--------------------------------------------------
0 | env      | CVRPEnv              | 0
1 | policy   | AttentionModelPolicy | 694 K
2 | baseline | WarmupBaseline       | 694 K
--------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.553     Total estimated model params size (MB)
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=3` reached.

Change the Encoder

In RL4CO, we provide two graph neural network encoders: the Graph Convolutional Network (GCN) encoder and the Message Passing Neural Network (MPNN) encoder. In this tutorial, we show how to swap the default encoder for one of them.
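For intuition about what a GCN layer computes, here is the textbook propagation rule from Kipf and Welling, ReLU(D^-1/2 (A + I) D^-1/2 H W), written in plain Python on a tiny graph. This is a sketch, not RL4CO's GCNEncoder, which relies on the graph dependencies (torch-geometric) and may build the graph and normalize it differently; gcn_layer is a hypothetical helper name.

```python
import math
from typing import List

Matrix = List[List[float]]


def gcn_layer(adj: Matrix, h: Matrix, weight: Matrix) -> Matrix:
    """One GCN layer: ReLU(D^-1/2 (A + I) D^-1/2 H W)."""
    n = len(adj)
    # Add self-loops so each node keeps part of its own features
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in a_hat]
    # Symmetric degree normalization of the adjacency matrix
    norm = [[a_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]
    # Aggregate neighbor features: norm @ h
    feat_dim = len(h[0])
    agg = [[sum(norm[i][k] * h[k][f] for k in range(n)) for f in range(feat_dim)] for i in range(n)]
    # Linear transform (agg @ weight) followed by ReLU
    out_dim = len(weight[0])
    return [
        [max(0.0, sum(agg[i][f] * weight[f][o] for f in range(feat_dim))) for o in range(out_dim)]
        for i in range(n)
    ]


# Path graph 0 - 1 - 2 with 2-dimensional features and an identity weight
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
out = gcn_layer(adj, h, identity)
print(len(out), len(out[0]))  # 3 2
```

An MPNN generalizes this idea by replacing the fixed normalized sum with learned message and update functions, which can also use edge features.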

Note: while we provide these two encoders, you can also implement your own encoder and use it in RL4CO! For instance, you may use different encoders (and decoders) to solve problems that require distance matrices as input.
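As a small illustration of such an input, the sketch below builds a pairwise Euclidean distance matrix from node coordinates in plain Python. The helper name distance_matrix is hypothetical; an encoder consuming this input would typically receive a batched tensor rather than nested lists.

```python
import math
from typing import List, Sequence


def distance_matrix(locs: Sequence[Sequence[float]]) -> List[List[float]]:
    """Pairwise Euclidean distances between node coordinates."""
    return [[math.dist(p, q) for q in locs] for p in locs]


# A depot followed by two customers
locs = [(0.0, 0.0), (3.0, 4.0), (0.0, 4.0)]
dm = distance_matrix(locs)
print(dm[0][1])  # 5.0
```

Note that math.dist requires Python 3.8 or later.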

[5]:
# Before we init, we need to install the graph neural network dependencies
# !pip install rl4co[graph]
[7]:
# Init the model with different encoder
from rl4co.models.nn.graph.gcn import GCNEncoder
from rl4co.models.nn.graph.mpnn import MessagePassingEncoder

gcn_encoder = GCNEncoder(
    env_name='cvrp',
    embedding_dim=128,
    num_nodes=20,
    num_layers=3,
)

mpnn_encoder = MessagePassingEncoder(
    env_name='cvrp',
    embedding_dim=128,
    num_nodes=20,
    num_layers=3,
)

model = AttentionModel(
    env,
    baseline='rollout',
    train_data_size=100_000, # really small size for demo
    val_data_size=10_000,
    policy_kwargs={
        'encoder': gcn_encoder # gcn_encoder or mpnn_encoder
    }
)

trainer = RL4COTrainer(
    max_epochs=3, # few epochs for demo
    accelerator='gpu',
    devices=1,
    logger=False,
)
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[8]:
# Train the model
trainer.fit(model)
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /datasets/home/botu/Dev/rl4co/notebooks/tutorials/checkpoints exists and is not empty.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type                 | Params
--------------------------------------------------
0 | env      | CVRPEnv              | 0
1 | policy   | AttentionModelPolicy | 148 K
2 | baseline | WarmupBaseline       | 148 K
--------------------------------------------------
297 K     Trainable params
0         Non-trainable params
297 K     Total params
1.191     Total estimated model params size (MB)
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/datasets/home/botu/mambaforge/envs/rl4co-new/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=3` reached.
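Conceptually, passing the encoder through policy_kwargs amounts to forwarding keyword arguments to the policy constructor, which then uses the given encoder instead of its default. The plain-Python sketch below mimics that pattern with hypothetical stand-in classes (Model, Policy, DefaultEncoder, CustomEncoder); it illustrates the idea only and is not RL4CO's actual AttentionModel code.

```python
from typing import Any, Dict, Optional


class DefaultEncoder:
    """Stand-in for the default encoder (e.g. a graph attention encoder)."""


class CustomEncoder:
    """Stand-in for a user-supplied encoder (e.g. a GCN encoder)."""


class Policy:
    def __init__(self, encoder: Optional[object] = None, embedding_dim: int = 128):
        # Fall back to the default encoder only when none is supplied
        self.encoder = encoder if encoder is not None else DefaultEncoder()
        self.embedding_dim = embedding_dim


class Model:
    def __init__(self, policy_kwargs: Optional[Dict[str, Any]] = None):
        # Every entry of policy_kwargs is forwarded to the policy constructor
        self.policy = Policy(**(policy_kwargs or {}))


default_model = Model()
custom_model = Model(policy_kwargs={"encoder": CustomEncoder()})
print(type(default_model.policy.encoder).__name__)  # DefaultEncoder
print(type(custom_model.policy.encoder).__name__)   # CustomEncoder
```

The same "use the default unless something is passed in" pattern appears in the init_embedding argument of the base class below.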

Create Your Own Encoder

If you want to create a new encoder, you can follow the base class below, which includes the following components:

  1. RL4CO provides the env_init_embedding function, which returns the initial embedding module for a given environment. You can use it to compute the initial node embeddings from the environment state.

  2. The returned hidden features h and init_h have the shape ([batch_size], num_nodes, hidden_size).

  3. In RL4CO, the graph neural network encoders live in the rl4co/models/nn/graph folder. You may want to put your customized encoder in the same folder. Feel free to send a PR to add your encoder to RL4CO!

[9]:
# Import necessary packages
import torch.nn as nn
from torch import Tensor
from tensordict import TensorDict
from typing import Optional, Tuple
from rl4co.models.nn.env_embeddings import env_init_embedding


class BaseEncoder(nn.Module):
    def __init__(
            self,
            env_name: str,
            embedding_dim: int,
            init_embedding: Optional[nn.Module] = None,
        ):
        super().__init__()
        self.env_name = env_name

        # Init embedding for each environment: use RL4CO's default
        # unless a custom embedding module is passed in
        self.init_embedding = (
            env_init_embedding(self.env_name, {"embedding_dim": embedding_dim})
            if init_embedding is None
            else init_embedding
        )

        # Define your own encoder layers here (e.g. a stack of GNN layers)

    def forward(
        self, td: TensorDict, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            td: Input TensorDict containing the environment state
            mask: Mask to apply to the attention

        Returns:
            h: Latent representation of the input
            init_h: Initial embedding of the input
        """
        # Initial node embeddings: ([batch_size], num_nodes, embedding_dim)
        init_h = self.init_embedding(td)
        # Placeholder: replace with the output of your encoder layers applied to init_h
        h = None
        return h, init_h
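To see how the pieces fit together without any dependencies, here is a plain-Python toy encoder that follows the same contract as BaseEncoder: an init embedding maps raw node features to hidden vectors (init_h), a stack of layers refines them into h, and forward returns (h, init_h). Everything here is hypothetical and for illustration only: the dictionary key "node_features", the linear init embedding, and the mean-aggregation layer stand in for env_init_embedding and real GNN layers, and the batch dimension is omitted, so outputs have shape (num_nodes, hidden_size).

```python
from typing import Callable, Dict, List, Optional, Tuple

Matrix = List[List[float]]
TD = Dict[str, Matrix]  # stand-in for a TensorDict holding a single instance


def linear_init_embedding(weight: Matrix) -> Callable[[TD], Matrix]:
    """Hypothetical init embedding: linearly project raw node features to hidden_size."""

    def embed(td: TD) -> Matrix:
        feats = td["node_features"]  # hypothetical key, shape (num_nodes, feat_dim)
        hidden_size = len(weight[0])
        return [
            [sum(x[f] * weight[f][o] for f in range(len(x))) for o in range(hidden_size)]
            for x in feats
        ]

    return embed


class MeanAggregationEncoder:
    """Toy encoder following the BaseEncoder contract: forward returns (h, init_h)."""

    def __init__(self, init_embedding: Callable[[TD], Matrix], num_layers: int = 2):
        self.init_embedding = init_embedding
        self.num_layers = num_layers

    def forward(self, td: TD, mask: Optional[Matrix] = None) -> Tuple[Matrix, Matrix]:
        # mask is accepted for interface compatibility but unused in this toy example
        init_h = self.init_embedding(td)
        h = init_h
        for _ in range(self.num_layers):
            # Each "layer" moves every node embedding halfway toward the graph mean
            mean = [sum(col) / len(h) for col in zip(*h)]
            h = [[0.5 * (hi + mi) for hi, mi in zip(row, mean)] for row in h]
        return h, init_h


# Three nodes with hypothetical (x, y, demand) features, projected to hidden_size=4
td = {"node_features": [[0.0, 0.0, 0.0], [1.0, 2.0, 0.5], [2.0, 1.0, 0.3]]}
weight = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]]
encoder = MeanAggregationEncoder(linear_init_embedding(weight))
h, init_h = encoder.forward(td)
print(len(h), len(h[0]))  # 3 4
```

In a real RL4CO encoder, the layers would be nn.Module instances operating on batched tensors, but the flow from init_h to h is the same.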