Shortcuts

Source code for rl4co.models.rl.common.critic

from typing import Callable, Optional, Union

from tensordict import TensorDict
from torch import Tensor, nn

from rl4co.envs import RL4COEnvBase
from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.nn.graph.attnnet import GraphAttentionNetwork


class CriticNetwork(nn.Module):
    """Critic (state-value) network usable with any environment.

    Refactored from Kool et al. (2019), which only worked for TSP. Here the
    critic is made problem-agnostic by running the input through an
    environment-specific initial embedding before the encoder.

    The forward pipeline is ``init_embedding -> encoder -> value_head``,
    followed by a mean over dimension 1 of the value head output.

    Note:
        If neither ``env_name`` nor ``init_embedding`` is provided, the
        initial embedding is ``nn.Identity()``, i.e. the input is passed to
        the encoder untransformed and must already be an embedding tensor.
        NOTE(review): the original documentation states the expected shape
        is ``(batch_size, embedding_dim)``, but ``forward`` averages over
        dimension 1 after the value head, which suggests a
        ``(batch_size, num_nodes, embedding_dim)`` layout -- confirm.

    Args:
        env_name: Environment name to solve. An ``RL4COEnvBase`` instance is
            also accepted, in which case its ``name`` attribute is used.
        encoder: Encoder to use for the critic. If ``None``, a
            ``GraphAttentionNetwork`` is built from the arguments below.
        init_embedding: Initial embedding to use for the critic. Takes
            precedence over the embedding derived from ``env_name``.
        embedding_dim: Dimension of the embeddings
        hidden_dim: Hidden dimension for the feed-forward network (used both
            by the default encoder and by the value head)
        num_layers: Number of layers for the encoder
        num_heads: Number of heads for the attention
        normalization: Normalization to use for the attention
        sdpa_fn: Scaled dot product function to use for the attention
        **unused_kwargs: Accepted and ignored, so that callers can forward a
            shared keyword dictionary without filtering it first.
    """

    def __init__(
        self,
        env_name: Optional[Union[str, RL4COEnvBase]] = None,
        encoder: Optional[nn.Module] = None,
        init_embedding: Optional[nn.Module] = None,
        embedding_dim: int = 128,
        hidden_dim: int = 512,
        num_layers: int = 3,
        num_heads: int = 8,
        normalization: str = "batch",
        sdpa_fn: Optional[Callable] = None,
        **unused_kwargs,
    ):
        super().__init__()

        # Allow passing the environment object itself instead of its name.
        if isinstance(env_name, RL4COEnvBase):
            env_name = env_name.name
        self.env_name = env_name

        # Submodules are assigned in the order init_embedding -> encoder ->
        # value_head so that parameter / state-dict ordering is unchanged.
        if init_embedding is not None:
            # An explicit embedding always wins over the env-derived one.
            self.init_embedding = init_embedding
        elif env_name is None:
            self.init_embedding = nn.Identity()  # No embedding
        else:
            self.init_embedding = env_init_embedding(
                self.env_name, {"embedding_dim": embedding_dim}
            )

        # Only build the default encoder when the caller did not supply one.
        if encoder is None:
            encoder = GraphAttentionNetwork(
                num_heads=num_heads,
                embedding_dim=embedding_dim,
                num_layers=num_layers,
                normalization=normalization,
                feed_forward_hidden=hidden_dim,
                sdpa_fn=sdpa_fn,
            )
        self.encoder = encoder

        # Two-layer MLP mapping each embedding to a single scalar.
        self.value_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: Union[Tensor, TensorDict]) -> Tensor:
        """Encode the input in embedding space and return its value.

        Args:
            x: Input containing the environment state. Can be a Tensor or a
                TensorDict

        Returns:
            Value of the input state: the value head output averaged over
            dimension 1.
        """
        # Initial embedding of x. This is the identity function when neither
        # an environment name nor an explicit init embedding was given.
        embedded = self.init_embedding(x)
        encoded = self.encoder(embedded)
        return self.value_head(encoded).mean(1)

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.