New Environment: Creating and Modeling

In this notebook, we will show how to extend RL4CO to solve new problems from zero to hero! 🚀

Contents

  1. Environment

  2. Modeling

  3. Training

  4. Evaluation

Problem: MDPDP

Let us consider a new, complex problem: the Open Multi-Depot Pickup and Delivery Problem (MDPDP).

  • Objective: find a set of routes for a fleet of vehicles to pick up and deliver a set of orders such that the total route length is minimized. Vehicles start from different depots, and the problem is open since they do not need to return to a depot.

  • Constraints:

      • The maximum number of vehicles (i.e. agents) is fixed

      • Each vehicle has a maximum capacity

      • Each pickup and delivery pair has to be served by the same vehicle, with a precedence constraint (no delivery before its pickup)

The MDPDP is a complex, realistic problem that can be solved with RL4CO. For instance, it arises in ride-sharing, where a fleet of vehicles (e.g. taxis) has to pick up passengers and bring them to their destinations, and in food delivery, where a fleet of vehicles (e.g. riders) has to pick up orders and deliver them to customers.
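
The capacity and precedence constraints above can be sketched as a small feasibility check for a single vehicle's route. This is only an illustration in plain Python: `is_feasible_route` is a hypothetical helper, not part of RL4CO, and it assumes the node indexing used later in this notebook (pickups first, then the paired deliveries) and reads capacity as the number of orders on board at once.

```python
def is_feasible_route(route, num_orders, capacity):
    """Check one vehicle's route against capacity and precedence.

    Hypothetical helper: pickups are nodes 0..num_orders-1 and the delivery
    paired with pickup k is node num_orders + k.
    """
    on_board = set()
    for node in route:
        if node < num_orders:  # pickup
            on_board.add(node)
            if len(on_board) > capacity:
                return False  # capacity exceeded
        else:  # delivery
            order = node - num_orders
            if order not in on_board:
                return False  # delivered before pickup, or picked up by another vehicle
            on_board.remove(order)
    return not on_board  # every picked-up order must also be delivered


num_orders, capacity = 4, 2
ok = is_feasible_route([0, 1, 4, 5], num_orders, capacity)
bad_precedence = is_feasible_route([4, 0], num_orders, capacity)
bad_capacity = is_feasible_route([0, 1, 2, 4, 5, 6], num_orders, capacity)
print(ok, bad_precedence, bad_capacity)  # True False False
```

In the environment built below, these rules are not checked after the fact; they are enforced during decoding through the action mask.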

Installation

[1]:
## Uncomment the following line to install the package from PyPI
## You may need to restart the runtime in Colab after this
## Remember to choose a GPU runtime for faster training!

# !pip install rl4co

Imports

[2]:


%load_ext autoreload
%autoreload 2

import sys; sys.path.append(2*"../")

from typing import Optional

from einops import rearrange
from matplotlib.axes import Axes
import torch
import torch.nn as nn
from tensordict.tensordict import TensorDict
from torchrl.data import (
    BoundedTensorSpec,
    CompositeSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.utils.ops import gather_by_index, get_tour_length
from rl4co.models.nn.utils import rollout, random_policy
from rl4co.models.zoo import AttentionModel, AutoregressivePolicy
from rl4co.utils.trainer import RL4COTrainer

Environment Creation

We will base environment creation on the RL4COEnvBase class, which is based on TorchRL:

[3]:
?? RL4COEnvBase
Init signature:
 RL4COEnvBase(
    *,
    data_dir: str = 'data/',
    train_file: str = None,
    val_file: str = None,
    test_file: str = None,
    check_solution: bool = True,
    seed: int = None,
    device: str = 'cpu',
    **kwargs,
)
Source:
class RL4COEnvBase(EnvBase):
    """Base class for RL4CO environments based on TorchRL EnvBase

    Args:
        data_dir: Root directory for the dataset
        train_file: Name of the training file
        val_file: Name of the validation file
        test_file: Name of the test file
        check_solution: Whether to check the validity of the solution at the end of the episode
        seed: Seed for the environment
        device: Device to use. Generally, no need to set as tensors are updated on the fly
    """

    batch_locked = False

    def __init__(
        self,
        *,
        data_dir: str = "data/",
        train_file: str = None,
        val_file: str = None,
        test_file: str = None,
        check_solution: bool = True,
        seed: int = None,
        device: str = "cpu",
        **kwargs,
    ):
        super().__init__(device=device, batch_size=[])
        self.data_dir = data_dir
        self.train_file = pjoin(data_dir, train_file) if train_file is not None else None
        self.val_file = pjoin(data_dir, val_file) if val_file is not None else None
        self.test_file = pjoin(data_dir, test_file) if test_file is not None else None
        self.check_solution = check_solution
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    def _step(self, td: TensorDict) -> TensorDict:
        """Step function to call at each step of the episode containing an action.
        Gives the next observation, reward, done
        """
        raise NotImplementedError

    def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict:
        """Reset function to call at the beginning of each episode"""
        raise NotImplementedError

    def _make_spec(self, td_params: TensorDict = None):
        """Make the specifications of the environment (observation, action, reward, done)"""
        raise NotImplementedError

    def get_reward(self, td, actions) -> TensorDict:
        """Function to compute the reward. Can be called by the agent to compute the reward of the current state
        This is faster than calling step() and getting the reward from the returned TensorDict at each time for CO tasks
        """
        raise NotImplementedError

    def get_action_mask(self, td: TensorDict) -> torch.Tensor:
        """Function to compute the action mask (feasible actions) for the current state
        Action mask is 1 if the action is feasible, 0 otherwise
        """
        raise NotImplementedError

    def check_solution_validity(self, td, actions) -> TensorDict:
        """Function to check whether the solution is valid. Can be called by the agent to check the validity of the current state
        This is called with the full solution (i.e. all actions) at the end of the episode
        """
        raise NotImplementedError

    def dataset(self, batch_size=[], phase="train", filename=None):
        """Return a dataset of observations
        Generates the dataset if it does not exist, otherwise loads it from file
        """
        if filename is not None:
            log.info(f"Overriding dataset filename from {filename}")
        f = getattr(self, f"{phase}_file") if filename is None else filename
        if f is None:
            if phase != "train":
                log.warning(f"{phase}_file not set. Generating dataset instead")
            td = self.generate_data(batch_size)
        else:
            log.info(f"Loading {phase} dataset from {f}")
            if phase == "train":
                log.warning(
                    "Loading training dataset from file. This may not be desired in RL since "
                    "the dataset is fixed and the agent will not be able to explore new states"
                )
            try:
                td = self.load_data(f, batch_size)
            except FileNotFoundError:
                log.error(
                    f"Provided file name {f} not found. Make sure to provide a file in the right path first or "
                    f"unset {phase}_file to generate data automatically instead"
                )
                td = self.generate_data(batch_size)

        return TensorDictDataset(td)

    def generate_data(self, batch_size):
        """Dataset generation"""
        raise NotImplementedError

    def transform(self):
        """Used for converting TensorDict variables (such as with torch.cat) efficiently
        https://pytorch.org/rl/reference/generated/torchrl.envs.transforms.Transform.html
        By default, we do not need to transform the environment since we use specific embeddings
        """
        return self

    def render(self, *args, **kwargs):
        """Render the environment"""
        raise NotImplementedError

    @staticmethod
    def load_data(fpath, batch_size=[]):
        """Dataset loading from file"""
        return load_npz_to_tensordict(fpath)

    def _set_seed(self, seed: Optional[int]):
        """Set the seed for the environment"""
        rng = torch.manual_seed(seed)
        self.rng = rng

    def __getstate__(self):
        """Return the state of the environment. By default, we want to avoid pickling
        the random number generator directly as it is not allowed by `deepcopy`
        """
        state = self.__dict__.copy()
        state["rng"] = state["rng"].get_state()
        return state

    def __setstate__(self, state):
        """Set the state of the environment. By default, we want to avoid pickling
        the random number generator directly as it is not allowed by `deepcopy`
        """
        self.__dict__.update(state)
        self.rng = torch.manual_seed(0)
        self.rng.set_state(state["rng"])
File:           ~/Dev/rl4co-rebuttal/rl4co/envs/common/base.py
Type:           ABCMeta
Subclasses:     ATSPEnv, CVRPEnv, DPPEnv, FFSPEnv, MTSPEnv, OPEnv, PCTSPEnv, PDPEnv, TSPEnv

Reset

The _reset function is used to initialize the environment to an initial state. It returns a TensorDict of the initial state.

[4]:
def _reset(
    self, td: Optional[TensorDict] = None, batch_size: Optional[int] = None
) -> TensorDict:
    if batch_size is None:
        batch_size = self.batch_size if td is None else td.batch_size

    # Data generation: if not provided, generate a new batch of data
    if td is None or td.is_empty():
        td = self.generate_data(batch_size=batch_size)

    self.device = td.device

    td_reset = TensorDict(
        {
            "pickup_locs": td["pickup_locs"],
            "delivery_locs": td["delivery_locs"],
            "vehicle_locs": td["vehicle_locs"],
            "pickup_visited": torch.zeros(
                *td["pickup_locs"].shape[:-1], dtype=torch.bool, device=td.device
            ),
            "delivery_visited": torch.zeros(
                *td["delivery_locs"].shape[:-1], dtype=torch.bool, device=td.device
            ),
            "to_deliver": torch.zeros(
                *td["delivery_locs"].shape[:-1], dtype=torch.bool, device=td.device
            ),  # whether the delivery is to be delivered (i.e., the corresponding pickup is visited)
            "i": torch.zeros(batch_size, dtype=torch.int64, device=td.device),
            "current_vehicle_idx": torch.zeros(
                *batch_size, dtype=torch.long, device=td.device
            ),  # used to denote vehicle index
            "current_vehicle_loads": torch.zeros(
                *batch_size, dtype=torch.int64, device=td.device
            ),  # used to denote current vehicle loads
            "current_vehicle_max_loads": torch.zeros(
                *batch_size, dtype=torch.int64, device=td.device
            ),  # used to denote the maximal vehicle loads from the beginning
            "current_vehicle_pickup_visited": torch.zeros(
                *td["pickup_locs"].shape[:-1], 1, dtype=torch.bool, device=td.device
            ),  # used to denote whether the pickup is visited by the current vehicle
        },
        batch_size=batch_size,
    )

    # Compute action mask: mask out actions that are not allowed (e.g., delivery before pickup)
    td_reset.set("action_mask", self.get_action_mask(td_reset))
    return td_reset

Step

Environment _step: this defines the state update of the MDPDP given a TensorDict (td in the code) containing the current state and the action to take:

[5]:
def _step(self, td: TensorDict) -> TensorDict:
    # action: [batch_size, (pickups | deliveries | vehicles)]
    selected = td["action"]

    # The number of orders is set based on the number of vehicles and the capacity
    num_orders = self.num_vehicles * self.capacity

    # Identify the action type
    is_pickup = selected < num_orders  # bool mask -- [batch]
    is_delivery = (selected >= num_orders) & (selected < num_orders * 2)
    is_vehicle = selected >= num_orders * 2

    selected_pickups = selected[is_pickup]
    selected_deliveries = selected[is_delivery] - num_orders

    # ====  Pickup status update ====
    td["pickup_visited"][is_pickup, selected_pickups] = True

    # Update current load
    td["current_vehicle_loads"][is_pickup] += 1

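    # Count the pickups made by the current vehicle so far (capped at capacity)
    td["current_vehicle_max_loads"][is_pickup] += 1
    # Clamp the full tensor: boolean-mask indexing returns a copy, so an
    # in-place clamp on the masked selection would have no effect
    td["current_vehicle_max_loads"].clamp_(max=self.capacity)

    # Set the paired deliveries' 'to_deliver' flag to True
    td["to_deliver"][is_pickup, selected_pickups] = True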

    # Dropoff the load
    td["delivery_visited"][is_delivery, selected_deliveries] = True
    td["current_vehicle_loads"][is_delivery] -= 1

    # ====  Vehicle status update ====
    # Increment vehicle idx counter
    if td["i"][0] > 0:
        td["current_vehicle_idx"][is_vehicle] += 1
        # Clamp the full tensor in place (a masked selection would be a copy)
        td["current_vehicle_idx"].clamp_(max=self.num_vehicles)
    # Initialize vehicle load and max loads
    td["current_vehicle_loads"][is_vehicle] = 0
    td["current_vehicle_max_loads"][is_vehicle] = 0

    done = td["pickup_visited"].all(dim=-1) & td["delivery_visited"].all(dim=-1)
    # The reward is calculated outside via get_reward for efficiency, so we set it to -inf here
    reward = torch.ones_like(done) * float("-inf")

    # Return the updated state under the "next" key
    td_step = TensorDict(
        {
            "next": {
                "pickup_locs": td["pickup_locs"],
                "delivery_locs": td["delivery_locs"],
                "vehicle_locs": td["vehicle_locs"],
                "pickup_visited": td["pickup_visited"],
                "delivery_visited": td["delivery_visited"],
                "to_deliver": td["to_deliver"],
                "i": td["i"] + 1,
                "current_vehicle_idx": td["current_vehicle_idx"],
                "current_vehicle_loads": td["current_vehicle_loads"],
                "current_vehicle_max_loads": td["current_vehicle_max_loads"],
                "current_vehicle_pickup_visited": td[
                    "current_vehicle_pickup_visited"
                ],
                "current_node": selected,
                "done": done,
                "reward": reward,
            },
        },
        td.shape,
    )
    td_step["next"].set("action_mask", self.get_action_mask(td_step["next"]))
    return td_step
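
To make the action-type decoding at the top of `_step` concrete, here is a minimal standalone sketch with plain PyTorch. It reuses the index layout from the code above (pickups, then deliveries, then vehicles); the toy values and the names `order_of_delivery` and `vehicle_id` are illustrative and not taken from the environment.

```python
import torch

num_vehicles, capacity = 2, 3
num_orders = num_vehicles * capacity  # 6 pickups, 6 deliveries, 2 vehicles

# One selected action per batch element
selected = torch.tensor([12, 0, 6, 13, 5, 11])

# Node layout: [0, 6) pickups | [6, 12) deliveries | [12, 14) vehicles
is_pickup = selected < num_orders
is_delivery = (selected >= num_orders) & (selected < 2 * num_orders)
is_vehicle = selected >= 2 * num_orders

# Map delivery and vehicle nodes back to order and vehicle indices
order_of_delivery = selected[is_delivery] - num_orders
vehicle_id = selected[is_vehicle] - 2 * num_orders

print(is_pickup.tolist())          # [False, True, False, False, True, False]
print(order_of_delivery.tolist())  # [0, 5]
print(vehicle_id.tolist())         # [0, 1]
```

Because the three masks partition the batch, each batch element triggers exactly one of the pickup, delivery, or vehicle updates.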

Action Mask

The get_action_mask function ensures that we only select actions that are valid in the current state. For example, we cannot deliver an order that has not been picked up, and we cannot visit a location that has already been visited.

[6]:
def get_action_mask(self, td: TensorDict) -> torch.Tensor:
    # At the first decoding step, policy is allowed to select the first vehicle only.
    if td["i"][0] == 0:

        total_nodes = self.num_vehicles * self.capacity * 2 + self.num_vehicles

        action_mask = torch.zeros(
            *td.batch_size,
            total_nodes,
            dtype=torch.bool,
            device=td.device,
        )
        action_mask[:, self.num_vehicles * self.capacity * 2] = True
        return action_mask

    # Handling pickup action mask
    pickup_action_mask = ~td["pickup_visited"]
    pickup_action_mask[td["current_vehicle_max_loads"] == self.capacity] = False

    # If current vehicle carries "capacity" loads, it can only deliver
    pickup_action_mask[td["current_vehicle_loads"] >= self.capacity] = False

    # Handling delivery action mask
    delivery_action_mask = td["to_deliver"] & ~td["delivery_visited"]

    # Handling vehicle action mask
    # vehicle can be selected only if all scheduled deliveries are delivered
    # selecting vehicle indicates the end of the current vehicle's tour
    pd_action_mask = torch.cat([pickup_action_mask, delivery_action_mask], dim=-1)

    # Vehicle action mask becomes true only when the current vehicle finished all orders (i.e., the pairs of pickup and delivery)
    # here the next vehicle will be selected!
    vehicle_action_mask = (
        torch.nn.functional.one_hot(
            (td["current_vehicle_idx"] + 1).clamp(max=self.num_vehicles - 1),
            num_classes=self.num_vehicles,
        )
        .bool()
    )

    # "Terminate" action is allowed if
    # (1) the current vehicle is not able to visit any pickup, and
    # (2) the scheduled deliveries are all delivered
    vehicle_action_mask[pd_action_mask.sum(dim=-1) > 0] = False

    # Action mask is true when the action is allowed (i.e. feasible)
    action_mask = torch.cat([pd_action_mask, vehicle_action_mask], dim=-1)
    return action_mask
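
As an illustration of how such a mask is typically consumed by a policy, the sketch below sets the logits of infeasible actions to -inf before the softmax, so they receive zero probability. This is a generic masking pattern shown with made-up logits; it is an assumption about usage, not a description of RL4CO's internal decoding code.

```python
import torch

# Toy scores for 4 actions in a batch of size 1
logits = torch.tensor([[2.0, 1.0, 0.5, 0.0]])
# True = feasible, False = infeasible (same convention as get_action_mask)
action_mask = torch.tensor([[True, False, True, False]])

# Infeasible actions get -inf, hence probability 0 after the softmax
masked_logits = logits.masked_fill(~action_mask, float("-inf"))
probs = torch.softmax(masked_logits, dim=-1)

greedy_action = probs.argmax(dim=-1)
print(probs)
print(greedy_action)  # tensor([0])
```

Note that at least one action must remain feasible in every state; otherwise the softmax over an all -inf row produces NaNs.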

Reward function

The get_reward function is used to evaluate the reward given the solution (actions).

[7]:
def get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor:
    locs = torch.cat(
        [td["pickup_locs"], td["delivery_locs"], td["vehicle_locs"]], dim=-2
    )
    ordered_locs = gather_by_index(locs, actions)

    if ordered_locs.dim() == 2:  # batch size = 1
        ordered_locs = ordered_locs[None, ...]

    # Reorder the tours in two dimensions
    ordered_locs = rearrange(
        ordered_locs,
        "... (n c) two -> ... n c two",
        n=self.num_vehicles,  # number of vehicles
        c=self.capacity * 2 + 1,
        two=2,
    )  # [batch, num_vehicles, capacity * 2 + 1, 2]

    dists = (ordered_locs[..., :-1, :] - ordered_locs[..., 1:, :]).norm(p=2, dim=-1)
    dists = dists.sum(dim=(-1, -2))  # [batch]
    return -dists # negative distance is the reward
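
The reward computation can be checked by hand on a tiny instance. The sketch below uses plain PyTorch (`torch.gather` and `reshape` stand in for `gather_by_index` and `rearrange`) with hand-picked coordinates. Like the function above, it assumes that every vehicle's segment contains exactly one depot followed by `capacity` pickups and `capacity` deliveries.

```python
import torch

num_vehicles, capacity = 2, 1

# Batch of one instance; coordinates chosen so that distances are integers
pickup_locs = torch.tensor([[[3.0, 4.0], [1.0, 2.0]]])
delivery_locs = torch.tensor([[[3.0, 0.0], [1.0, 4.0]]])
vehicle_locs = torch.tensor([[[0.0, 0.0], [1.0, 1.0]]])
locs = torch.cat([pickup_locs, delivery_locs, vehicle_locs], dim=-2)  # [1, 6, 2]

# Vehicle 0 (node 4): pickup 0 -> delivery 0 (node 2)
# Vehicle 1 (node 5): pickup 1 -> delivery 1 (node 3)
actions = torch.tensor([[4, 0, 2, 5, 1, 3]])

# Gather the coordinates in visiting order
idx = actions.unsqueeze(-1).expand(-1, -1, 2)
ordered_locs = locs.gather(1, idx)  # [1, 6, 2]

# Split into one segment per vehicle; no edge is counted between segments
ordered_locs = ordered_locs.reshape(1, num_vehicles, capacity * 2 + 1, 2)
dists = (ordered_locs[..., :-1, :] - ordered_locs[..., 1:, :]).norm(p=2, dim=-1)
reward = -dists.sum(dim=(-1, -2))
print(reward)  # tensor([-12.])
```

Vehicle 0 travels 5 + 4 and vehicle 1 travels 1 + 2, for a total length of 12. There is no return leg to the depot, which reflects the open variant of the problem.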

Environment Action Specs

This defines the input and output domains of the environment, similar to Gym’s spaces. It is not strictly necessary, but it is useful to have a clear definition of the environment’s action and observation spaces, for example when sampling actions with TorchRL’s utilities.

[8]:
def _make_spec(self, td_params: TensorDict = None):
    self.observation_spec = CompositeSpec(
        delivery_locs=BoundedTensorSpec(
            minimum=self.min_loc,
            maximum=self.max_loc,
            shape=(self.num_vehicles * self.capacity, 2),
        ),
        pickup_locs=BoundedTensorSpec(
            minimum=self.min_loc,
            maximum=self.max_loc,
            shape=(self.num_vehicles * self.capacity, 2),
        ),
        vehicle_locs=BoundedTensorSpec(
            minimum=self.min_loc,
            maximum=self.max_loc,
            shape=(self.num_vehicles, 2),
        ),
        pickup_visited=UnboundedDiscreteTensorSpec(
            shape=(self.num_vehicles * self.capacity), dtype=torch.int64
        ),
        delivery_visited=UnboundedDiscreteTensorSpec(
            shape=(self.num_vehicles * self.capacity), dtype=torch.int64
        ),
        to_deliver=UnboundedDiscreteTensorSpec(
            shape=(self.num_vehicles * self.capacity), dtype=torch.int64
        ),
        current_vehicle_idx=UnboundedDiscreteTensorSpec(shape=(1), dtype=torch.int64),
        current_vehicle_loads=UnboundedDiscreteTensorSpec(
            shape=(1), dtype=torch.int64
        ),
        current_vehicle_max_loads=UnboundedDiscreteTensorSpec(
            shape=(1), dtype=torch.int64
        ),
        current_vehicle_pickup_visited=UnboundedDiscreteTensorSpec(
            shape=(self.num_vehicles * self.capacity), dtype=torch.int64
        ),
    )

    self.input_spec = self.observation_spec.clone()
    self.action_spec = BoundedTensorSpec(
        shape=(1,),
        dtype=torch.int64,
        minimum=0,
        maximum=self.num_vehicles + self.num_vehicles * self.capacity * 2,
    )
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
    self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)

Data generation

This function allows for generating data for training instances if no data is provided.

[9]:
def generate_data(self, batch_size) -> TensorDict:
    batch_size = [batch_size] if isinstance(batch_size, int) else batch_size

    # Generate random locations
    pickup_locs = (
        torch.FloatTensor(*batch_size, self.num_vehicles * self.capacity, 2)
        .uniform_(self.min_loc, self.max_loc)
        .to(self.device)
    )

    # Here we generate pickup and delivery locations not far apart from each other
    rs = (
        torch.FloatTensor(*batch_size, self.num_vehicles * self.capacity)
        .uniform_(0.0, 0.2)
        .to(self.device)
    )
    thetas = (
        torch.FloatTensor(*batch_size, self.num_vehicles * self.capacity)
        .uniform_(0.0, 2 * torch.pi)  # angles in radians, as expected by torch.cos / torch.sin
        .to(self.device)
    )

    delta_x = rs * torch.cos(thetas)
    delta_y = rs * torch.sin(thetas)

    delivery_locs = pickup_locs.clone() + torch.stack([delta_x, delta_y], dim=-1)

    vehicle_locs = (
        torch.FloatTensor(*batch_size, self.num_vehicles, 2)
        .uniform_(self.min_loc, self.max_loc)
        .to(self.device)
    )

    return TensorDict(
        {
            "pickup_locs": pickup_locs,
            "delivery_locs": delivery_locs,
            "vehicle_locs": vehicle_locs,
            "first_node": torch.zeros(
                *batch_size, dtype=torch.int64, device=self.device
            ),
        },
        batch_size=batch_size,
    )
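
The polar-offset idea used for the delivery locations can be tried in isolation. The following is a minimal sketch with plain PyTorch and toy sizes (the names `num_orders` and `max_radius` are illustrative): each delivery is placed at a random radius and a random angle, in radians, around its pickup.

```python
import math

import torch

torch.manual_seed(0)
num_orders, max_radius = 15, 0.2

pickup_locs = torch.rand(num_orders, 2)  # uniform in the unit square

# Random polar offset around each pickup
rs = torch.rand(num_orders) * max_radius
thetas = torch.rand(num_orders) * 2 * math.pi  # radians
offsets = torch.stack([rs * torch.cos(thetas), rs * torch.sin(thetas)], dim=-1)

delivery_locs = pickup_locs + offsets
pair_dists = (delivery_locs - pickup_locs).norm(dim=-1)
print(pair_dists.max() <= max_radius + 1e-6)  # tensor(True)
```

One consequence of this construction is that a delivery whose pickup lies near the border can land slightly outside the [min_loc, max_loc] square.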

Render function

The render function is optional, but can be useful for quickly visualizing the results of your algorithm!

[10]:
def render(
    self,
    td: TensorDict,
    actions: torch.Tensor = None,
    ax: Axes = None,
    batch_idx: int = None,
):
    import matplotlib.pyplot as plt

    def draw_line(src, dst, ax):
        ax.plot([src[0], dst[0]], [src[1], dst[1]], ls="--", c="gray")

    td = td.detach().cpu()

    if actions is None:
        actions = td.get("action", None)

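    if td.batch_size != torch.Size([]):
        # Render a single instance of the batch (the first one by default)
        batch_idx = 0 if batch_idx is None else batch_idx
        td = td[batch_idx]
        actions = actions[batch_idx] if actions is not None else None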

    pickup_locs = td["pickup_locs"]
    delivery_locs = td["delivery_locs"]
    vehicle_locs = td["vehicle_locs"]

    if ax is None:
        # Create a plot of the nodes
        _, ax = plt.subplots(1, 1, figsize=(6, 6))

    ax.axis("equal"); ax.grid(True)

    ax.scatter(pickup_locs[:, 0], pickup_locs[:, 1], marker="x", color="gray")
    ax.scatter(delivery_locs[:, 0], delivery_locs[:, 1], marker="^", color="gray")

    for v_i, (v_x, v_y) in enumerate(zip(vehicle_locs[:, 0], vehicle_locs[:, 1])):
        ax.scatter(v_x, v_y, color=f"C{v_i}")

    for p_loc, d_loc in zip(pickup_locs, delivery_locs):
        draw_line(p_loc, d_loc, ax)

    if actions is not None:  # draw solution if available.
        sub_tours = actions.reshape(self.num_vehicles, -1)
        loc = torch.cat([pickup_locs, delivery_locs, vehicle_locs], dim=0)
        for v_i, sub_tour in enumerate(sub_tours):
            ax.plot(loc[sub_tour][:, 0], loc[sub_tour][:, 1], color=f"C{v_i}")

Putting everything together

[11]:
class MDPDPEnv(RL4COEnvBase):
    """Multi-Depot Pickup and Delivery Problem (MDPDP) environment
    We consider the number of delivery locations to be num_vehicles * capacity,
    with one corresponding pickup location for each delivery location.
    The total number of locations will be:
        - num_vehicles + (2 * num_vehicles * capacity)
    """

    name = "MDPDP"

    def __init__(
        self,
        num_vehicles: int = 5,
        capacity: int = 3,
        min_loc: float = 0.0,
        max_loc: float = 1.0,
        td_params: TensorDict = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_vehicles = num_vehicles
        self.capacity = capacity
        self.min_loc = min_loc
        self.max_loc = max_loc
        self._make_spec(td_params)

    _reset = _reset
    _step = _step
    get_reward = get_reward
    get_action_mask = get_action_mask
    _make_spec = _make_spec
    generate_data = generate_data
    render = render

[12]:
batch_size = 2

env = MDPDPEnv(num_vehicles=5, capacity=3)
reward, td, actions = rollout(env, env.reset(batch_size=[batch_size]), random_policy)
env.render(td, actions)
[Figure: rendered MDPDP instance with the routes of the random policy]

Modeling

Now we need to model the problem by transforming the input information into a latent space to be processed. In RL4CO, we divide embeddings into three parts:

  • init_embedding: embed initial states of the problem

  • context_embedding: embed context information of the problem for the current partial solution to modify the query

  • dynamic_embedding: embed dynamic information of the problem for the current partial solution to modify the query, key, and value (i.e. if other nodes also change state)

Init Embedding

Embed the initial problem into the latent space. In our case, we project the vehicle, pickup, and delivery locations. Each pickup is embedded together with the location of its paired delivery.

[13]:
class MDPDPInitEmbedding(nn.Module):
    def __init__(self, embedding_dim: int):
        super().__init__()
        node_dim = 2  # x, y
        self.init_embed_pick = nn.Linear(node_dim * 2, embedding_dim)
        self.init_embed_delivery = nn.Linear(node_dim, embedding_dim)
        self.init_embed_vehicle = nn.Linear(2, embedding_dim)

    def forward(self, td: TensorDict):
        pickup_emb = self.init_embed_pick(
            torch.cat([td["pickup_locs"], td["delivery_locs"]], dim=-1)
        )
        delivery_emb = self.init_embed_delivery(td["delivery_locs"])
        vehicle_emb = self.init_embed_vehicle(td["vehicle_locs"])
        return torch.cat([pickup_emb, delivery_emb, vehicle_emb], dim=-2)
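
A quick shape check helps confirm that the embedding order matches the action index layout (pickups, deliveries, vehicles). The sketch below mirrors the three projections with bare `nn.Linear` layers and random inputs; the variable names and sizes are illustrative only.

```python
import torch
import torch.nn as nn

batch, num_vehicles, capacity, emb_dim = 4, 5, 3, 128
num_orders = num_vehicles * capacity

pickup_locs = torch.rand(batch, num_orders, 2)
delivery_locs = torch.rand(batch, num_orders, 2)
vehicle_locs = torch.rand(batch, num_vehicles, 2)

embed_pick = nn.Linear(4, emb_dim)      # pickup (x, y) + paired delivery (x, y)
embed_delivery = nn.Linear(2, emb_dim)
embed_vehicle = nn.Linear(2, emb_dim)

node_emb = torch.cat(
    [
        embed_pick(torch.cat([pickup_locs, delivery_locs], dim=-1)),
        embed_delivery(delivery_locs),
        embed_vehicle(vehicle_locs),
    ],
    dim=-2,
)
print(node_emb.shape)  # torch.Size([4, 35, 128])
```

The node dimension is 15 + 15 + 5 = 35, which equals the number of actions, so action index k refers to row k of the embedding.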

Context Embedding

The context embedding takes the current context and returns a vector representation of it. In the MDPDP, the context is the embedding of the current node. At the first step, no node has been selected yet, so a learned placeholder is used instead.

[14]:
class MDPDPContextEmbedding(nn.Module):
    def __init__(self, embedding_dim, step_context_dim=None, linear_bias=False):
        super(MDPDPContextEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.W_placeholder = nn.Parameter(torch.Tensor(self.embedding_dim).uniform_(-1, 1))
        self.project_context = nn.Linear(
            embedding_dim, embedding_dim, bias=linear_bias
        )

    def _cur_node_embedding(self, embeddings, td):
        if td["i"][0].item() == 0:
            batch_size = embeddings.size(0)
            context_embedding = self.W_placeholder[None, :].expand(
                batch_size, self.W_placeholder.size(-1)
            )
            return context_embedding

        cur_node_embedding = gather_by_index(embeddings, td["current_node"])
        return cur_node_embedding

    def forward(self, embeddings, td):
        cur_node_embedding = self._cur_node_embedding(embeddings, td).squeeze()
        return self.project_context(cur_node_embedding)
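
The lookup of the current node embedding can be sketched with `torch.gather`. This assumes that `gather_by_index` selects one row per batch element along the node dimension; the tensors below are toy values.

```python
import torch

batch, num_nodes, emb_dim = 2, 5, 3
embeddings = torch.arange(batch * num_nodes * emb_dim, dtype=torch.float32).reshape(
    batch, num_nodes, emb_dim
)
current_node = torch.tensor([3, 0])  # last selected node per batch element

# Expand the index to [batch, 1, emb_dim] and gather along the node dimension
idx = current_node[:, None, None].expand(-1, 1, emb_dim)
cur_node_embedding = embeddings.gather(1, idx).squeeze(1)  # [batch, emb_dim]
print(cur_node_embedding)
# tensor([[ 9., 10., 11.],
#         [15., 16., 17.]])
```

Squeezing only the gathered dimension with `squeeze(1)` keeps the batch dimension intact even when the batch size is 1, whereas a bare `squeeze()` would drop it.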

Dynamic Embedding

Since the states do not change except for the visited nodes, we do not need to modify the keys and values. Therefore, the dynamic embedding simply returns 0 for each of them.

[15]:
class StaticEmbedding(nn.Module):
    def __init__(self, *args, **kwargs):
        super(StaticEmbedding, self).__init__()

    def forward(self, td):
        return 0, 0, 0

Training our Model

[16]:
# Instantiate our environment
env = MDPDPEnv(num_vehicles=5, capacity=3)

# Instantiate policy with the embeddings we created above
emb_dim = 128
policy = AutoregressivePolicy(env,
                              embedding_dim=emb_dim,
                              init_embedding=MDPDPInitEmbedding(emb_dim),
                              context_embedding=MDPDPContextEmbedding(emb_dim),
                              dynamic_embedding=StaticEmbedding(emb_dim)
)


# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(env,
                       policy=policy,
                       baseline='rollout',
                       train_data_size=100_000,
                       val_data_size=10_000)

Rollout untrained model

[17]:
# Greedy rollouts over untrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = env.reset(batch_size=[3]).to(device)
model = model.to(device)
out = model(td_init.clone(), phase="test", decode_type="greedy", return_actions=True)

# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['7.96', '7.60', '10.26']
[Figures: greedy rollouts of the untrained model on the three instances]

Training loop

[18]:
# We use our own wrapper around Lightning's `Trainer` to make it easier to use
trainer = RL4COTrainer(max_epochs=3)
trainer.fit(model)
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                 | Params
--------------------------------------------------
0 | env      | MDPDPEnv             | 0
1 | policy   | AutoregressivePolicy | 694 K
2 | baseline | WarmupBaseline       | 694 K
--------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.557     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=3` reached.

Evaluation

[19]:
# Greedy rollouts over trained model (same states as previous plot)
model = model.to(device)
out = model(td_init, phase="test", decode_type="greedy", return_actions=True)

# Plotting
print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
for td, actions in zip(td_init, out['actions'].cpu()):
    env.render(td, actions)
Tour lengths: ['4.82', '5.53', '5.22']
[Figures: greedy rollouts of the trained model on the same three instances]

We can see that the solutions are much better than with the untrained model: tour lengths of roughly 4.8 to 5.5 instead of 7.6 to 10.3, after just 3 epochs! 🚀
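
To quantify the improvement, we can compare the tour lengths printed by the two rollout cells. The snippet below simply copies those six numbers; with only three instances, the result is an indication rather than a proper benchmark.

```python
# Tour lengths printed above for the same three instances
before = [7.96, 7.60, 10.26]  # untrained model
after = [4.82, 5.53, 5.22]    # after 3 epochs of training

mean_before = sum(before) / len(before)
mean_after = sum(after) / len(after)
reduction = 1 - mean_after / mean_before

print(f"{mean_before:.2f} -> {mean_after:.2f} ({reduction:.1%} shorter)")
# 8.61 -> 5.19 (39.7% shorter)
```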