Shortcuts

Source code for rl4co.models.nn.env_embeddings.context

import torch
import torch.nn as nn

from rl4co.utils.ops import gather_by_index


def env_context_embedding(env_name: str, config: dict) -> nn.Module:
    """Build the context embedding registered for an environment.

    The context embedding is used to modify the query embedding of the problem
    node of the current partial solution. It usually consists of a projection
    of gathered node embeddings and features to the embedding space.

    Args:
        env_name: Name of the environment; used as the registry lookup key.
        config: Keyword arguments forwarded to the embedding class constructor.

    Returns:
        An instance of the context embedding class mapped to ``env_name``.

    Raises:
        ValueError: If ``env_name`` has no registered context embedding.
    """
    embedding_registry = {
        "tsp": TSPContext,
        "atsp": TSPContext,
        "cvrp": VRPContext,
        "sdvrp": VRPContext,
        "pctsp": PCTSPContext,
        "spctsp": PCTSPContext,
        "op": OPContext,
        "dpp": DPPContext,
        "mdpp": DPPContext,
        "pdp": PDPContext,
        "mtsp": MTSPContext,
        "smtwtp": SMTWTPContext,
    }

    # No registry value is None, so a None result means the key is missing.
    context_cls = embedding_registry.get(env_name)
    if context_cls is None:
        raise ValueError(
            f"Unknown environment name '{env_name}'. Available context embeddings: {embedding_registry.keys()}"
        )
    return context_cls(**config)
class EnvContext(nn.Module):
    """Base class for environment context embeddings.

    The context embedding is used to modify the query embedding of the problem
    node of the current partial solution. It consists of a linear layer that
    projects the node features to the embedding space.

    Args:
        embedding_dim: Output size of the projection.
        step_context_dim: Input size of the projection; falls back to
            ``embedding_dim`` when left as ``None``.
        linear_bias: Whether the projection layer carries a bias term.
    """

    def __init__(self, embedding_dim, step_context_dim=None, linear_bias=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        if step_context_dim is None:
            step_context_dim = embedding_dim
        self.project_context = nn.Linear(
            step_context_dim, embedding_dim, bias=linear_bias
        )

    def _cur_node_embedding(self, embeddings, td):
        """Gather the embedding selected by ``td["current_node"]``."""
        return gather_by_index(embeddings, td["current_node"])

    def _state_embedding(self, embeddings, td):
        """Return the environment-specific state features (subclass hook)."""
        raise NotImplementedError("Implement for each environment")

    def forward(self, embeddings, td):
        """Project the concatenated current-node and state embeddings.

        The two parts are joined along the last dimension before projection.
        """
        parts = [
            self._cur_node_embedding(embeddings, td),
            self._state_embedding(embeddings, td),
        ]
        return self.project_context(torch.cat(parts, -1))
class TSPContext(EnvContext):
    """Context embedding for the Traveling Salesman Problem (TSP).

    Project the following to the embedding space:
        - first node embedding
        - current node embedding
    """

    def __init__(self, embedding_dim):
        super().__init__(embedding_dim, 2 * embedding_dim)
        # Learned stand-in context, used while the first entry of td["i"] is < 1.
        placeholder = torch.Tensor(2 * self.embedding_dim).uniform_(-1, 1)
        self.W_placeholder = nn.Parameter(placeholder)

    def forward(self, embeddings, td):
        """Project either the placeholder or the (first, current) node pair."""
        batch_size = embeddings.size(0)
        first_node = td["first_node"]

        # By default, node_dim = -1 (we only have one node embedding per node)
        if first_node.dim() == 1:
            node_dim = (-1,)
        else:
            node_dim = (first_node.size(-1), -1)

        step = td["i"]
        # Read only the very first element of td["i"]: a cheap scalar check.
        first_entry = (0,) * step.dim()
        if step[first_entry].item() < 1:
            expanded = self.W_placeholder[None, :].expand(
                batch_size, self.W_placeholder.size(-1)
            )
            return self.project_context(expanded)

        # Pair first/current indices, gather both embeddings, then flatten them
        # per batch entry according to node_dim.
        node_pair = torch.stack([first_node, td["current_node"]], -1)
        gathered = gather_by_index(embeddings, node_pair.view(batch_size, -1))
        return self.project_context(gathered.view(batch_size, *node_dim))
class VRPContext(EnvContext):
    """Context embedding for the Capacitated Vehicle Routing Problem (CVRP).

    Project the following to the embedding space:
        - current node embedding
        - remaining capacity (vehicle_capacity - used_capacity)
    """

    def __init__(self, embedding_dim):
        # One extra input feature: the remaining capacity.
        super().__init__(embedding_dim, embedding_dim + 1)

    def _state_embedding(self, embeddings, td):
        """Return ``td["vehicle_capacity"] - td["used_capacity"]``."""
        capacity, used = td["vehicle_capacity"], td["used_capacity"]
        return capacity - used
class PCTSPContext(EnvContext):
    """Context embedding for the Prize Collecting TSP (PCTSP).

    Project the following to the embedding space:
        - current node embedding
        - remaining prize (prize_required - cur_total_prize)
    """

    def __init__(self, embedding_dim):
        # One extra input feature: the remaining prize.
        super().__init__(embedding_dim, embedding_dim + 1)

    def _state_embedding(self, embeddings, td):
        """Return the prize still required, floored at 0, with a trailing axis."""
        remaining_prize = td["prize_required"] - td["cur_total_prize"]
        return torch.clamp(remaining_prize, min=0).unsqueeze(-1)
class OPContext(EnvContext):
    """Context embedding for the Orienteering Problem (OP).

    Project the following to the embedding space:
        - current node embedding
        - remaining distance (max_length - tour_length)
    """

    def __init__(self, embedding_dim):
        # One extra input feature: the remaining distance.
        super().__init__(embedding_dim, embedding_dim + 1)

    def _state_embedding(self, embeddings, td):
        """Return the remaining distance with a trailing feature axis."""
        # Only index 0 of the last axis of max_length is used.
        length_budget = td["max_length"][..., 0]
        remaining = length_budget - td["tour_length"]
        return remaining.unsqueeze(-1)
class DPPContext(EnvContext):
    """Context embedding for the Decap Placement Problem (DPP), EDA (electronic design automation).

    Project the following to the embedding space:
        - current cell embedding
    """

    def __init__(self, embedding_dim):
        super().__init__(embedding_dim)

    def forward(self, embeddings, td):
        """Return an all-zero context of shape ``(embeddings.size(0), embedding_dim)``.

        Context cannot be defined by a single node embedding for DPP, hence 0.
        We modify the dynamic embedding instead to capture placed items.
        """
        batch_size = embeddings.size(0)
        # new_zeros makes the result follow the dtype and device of `embeddings`.
        return embeddings.new_zeros(batch_size, self.embedding_dim)
class PDPContext(EnvContext):
    """Context embedding for the Pickup and Delivery Problem (PDP).

    Project the following to the embedding space:
        - current node embedding
    """

    def __init__(self, embedding_dim):
        super().__init__(embedding_dim, embedding_dim)

    def forward(self, embeddings, td):
        """Project the current node embedding to the embedding space.

        Singleton dimensions of the gathered embedding are squeezed away, but
        the leading dimension is kept even when its size is 1. A bare
        ``squeeze()`` would also drop it, so a single-instance batch would lose
        its batch axis.
        """
        gathered = self._cur_node_embedding(embeddings, td)
        cur_node_embedding = gathered.squeeze()
        # Restore the leading axis that squeeze() removed when its size was 1.
        # Results for a leading size > 1 are identical to a bare squeeze().
        # NOTE(review): assumes the leading axis of the gathered tensor is the
        # batch axis — confirm against gather_by_index.
        if gathered.dim() > 0 and gathered.size(0) == 1:
            cur_node_embedding = cur_node_embedding.unsqueeze(0)
        return self.project_context(cur_node_embedding)
class MTSPContext(EnvContext):
    """Context embedding for the Multiple Traveling Salesman Problem (mTSP).

    Project the following to the embedding space:
        - current node embedding
        - remaining_agents
        - current_length
        - max_subtour_length
        - distance_from_depot

    Args:
        embedding_dim: Size of the embedding space.
        linear_bias: Whether the dynamic-feature projection carries a bias.
    """

    def __init__(self, embedding_dim, linear_bias=False):
        # NOTE(review): linear_bias is only applied to proj_dynamic_feats; the
        # base-class projection keeps its default (no bias). Confirm intent.
        super().__init__(embedding_dim, 2 * embedding_dim)
        proj_in_dim = (
            4  # remaining_agents, current_length, max_subtour_length, distance_from_depot
        )
        self.proj_dynamic_feats = nn.Linear(proj_in_dim, embedding_dim, bias=linear_bias)

    def _cur_node_embedding(self, embeddings, td):
        """Gather the current node embedding and drop singleton dimensions.

        The leading dimension is kept even when its size is 1. A bare
        ``squeeze()`` would also remove it, and the result could then no longer
        be concatenated with the state embedding in ``forward``.
        """
        gathered = gather_by_index(embeddings, td["current_node"])
        cur_node_embedding = gathered.squeeze()
        # Restore the leading axis that squeeze() removed when its size was 1.
        # Results for a leading size > 1 are identical to a bare squeeze().
        # NOTE(review): assumes the leading axis of the gathered tensor is the
        # batch axis — confirm against gather_by_index.
        if gathered.dim() > 0 and gathered.size(0) == 1:
            cur_node_embedding = cur_node_embedding.unsqueeze(0)
        return cur_node_embedding

    def _state_embedding(self, embeddings, td):
        """Stack the four dynamic features on a new last axis and project them."""
        dynamic_feats = torch.stack(
            [
                (td["num_agents"] - td["agent_idx"]).float(),
                td["current_length"],
                td["max_subtour_length"],
                self._distance_from_depot(td),
            ],
            dim=-1,
        )
        return self.proj_dynamic_feats(dynamic_feats)

    def _distance_from_depot(self, td):
        """Return the distance between the current location and the depot."""
        # Euclidean distance from the depot (loc[..., 0, :])
        cur_loc = gather_by_index(td["locs"], td["current_node"])
        return torch.norm(cur_loc - td["locs"][..., 0, :], dim=-1)
class SMTWTPContext(EnvContext):
    """Context embedding for the Single Machine Total Weighted Tardiness Problem (SMTWTP).

    Project the following to the embedding space:
        - current node embedding
        - current time
    """

    def __init__(self, embedding_dim):
        # One extra input feature: the current time.
        super().__init__(embedding_dim, embedding_dim + 1)

    def _cur_node_embedding(self, embeddings, td):
        """Gather the embedding selected by ``td["current_job"]``."""
        job_index = td["current_job"]
        return gather_by_index(embeddings, job_index)

    def _state_embedding(self, embeddings, td):
        """Return ``td["current_time"]`` unchanged as the state feature."""
        return td["current_time"]

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.