nn Modules

Graph Neural Networks

Graph Attention Encoder

class rl4co.models.nn.graph.attnnet.GraphAttentionNetwork(num_heads, embedding_dim, num_layers, normalization='batch', feed_forward_hidden=512, sdpa_fn=None)[source]

Bases: Module

Graph Attention Network that encodes embeddings with a stack of layers, each consisting of an MHA layer, normalization, a feed-forward layer, and a second normalization. Similar to the Transformer encoder, as used in Kool et al. (2019).

Parameters:
  • num_heads (int) – number of heads in the MHA

  • embedding_dim (int) – dimension of the embeddings

  • num_layers (int) – number of MHA layers

  • normalization (str) – type of normalization to use (batch, layer, none)

  • feed_forward_hidden (int) – dimension of the hidden layer in the feed-forward layer

  • sdpa_fn (Optional[Callable]) – scaled dot product attention function (SDPA)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, mask=None)[source]

Forward pass of the encoder

Parameters:
  • x (Tensor) – [batch_size, graph_size, embed_dim] initial embeddings to process

  • mask (Optional[Tensor]) – [batch_size, graph_size, graph_size] mask for the input embeddings. Unused for now.

Return type:

Tensor

class rl4co.models.nn.graph.attnnet.MultiHeadAttentionLayer(num_heads, embed_dim, feed_forward_hidden=512, normalization='batch', sdpa_fn=None)[source]

Bases: Sequential

Multi-Head Attention Layer with normalization and feed-forward layer

Parameters:
  • num_heads (int) – number of heads in the MHA

  • embed_dim (int) – dimension of the embeddings

  • feed_forward_hidden (int) – dimension of the hidden layer in the feed-forward layer

  • normalization (Optional[str]) – type of normalization to use (batch, layer, none)

  • sdpa_fn (Optional[Callable]) – scaled dot product attention function (SDPA)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Graph Convolutional Encoder

class rl4co.models.nn.graph.gcn.GCNEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, self_loop=False, residual=True)[source]

Bases: Module

Graph Convolutional Network to encode embeddings with a series of GCN layers

Parameters:
  • embedding_dim (int) – dimension of the embeddings

  • num_nodes (int) – number of nodes in the graph

  • num_layers (int) – number of GCN layers

  • self_loop (bool) – whether to add self loop in the graph

  • residual (bool) – whether to use residual connection

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td, mask=None)[source]

Forward pass of the encoder. Transform the input TensorDict into a latent representation.

Parameters:
  • td (TensorDict) – Input TensorDict containing the environment state

  • mask (Optional[Tensor]) – Mask to apply to the attention

Returns:
  • h – Latent representation of the input

  • init_h – Initial embedding of the input

Message Passing Encoder

class rl4co.models.nn.graph.mpnn.MessagePassingEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, aggregation='add', self_loop=False, residual=True)[source]

Bases: Module

Note

  • Only fully connected graphs are supported for now.

edge_update(nf, ef, edge_index)[source]

forward(td, mask=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor]

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

message(edge_features)[source]

update(aggr_msg, x)[source]

class rl4co.models.nn.graph.mpnn.MessagePassingLayer(node_indim, node_outdim, edge_indim, edge_outdim, aggregation='add', residual=False, **mlp_params)[source]

Bases: MessagePassing

Initializes internal Module state, shared by both nn.Module and ScriptModule.

edge_update(nf, ef, edge_index)[source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

forward(node_feature, edge_feature, edge_index, mask=None)[source]

Runs the forward pass of the module.

message(edge_features)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

update(aggr_msg, x)[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().
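The message, aggregate, and update stages described above can be illustrated with a minimal plain-Python sketch. It is not the library's implementation: the function name, the scalar features, and the multiplicative message function are all assumptions chosen for illustration, and only the 'add' aggregation and the optional residual mirror the documented parameters.

```python
def message_passing_step(node_feats, edge_index, edge_feats, residual=True):
    """One round of message passing with 'add' aggregation on scalar features.

    edge_index is a list of (j, i) pairs meaning "message from j to i";
    edge_feats[k] belongs to the k-th edge. All names are illustrative.
    """
    # message: a stand-in for a learned function of source node and edge
    messages = [node_feats[j] * edge_feats[k] for k, (j, _) in enumerate(edge_index)]

    # aggregate: sum incoming messages at each target node ('add')
    aggregated = [0.0] * len(node_feats)
    for (_, i), msg in zip(edge_index, messages):
        aggregated[i] += msg

    # update: combine the aggregate with the node's own feature
    if residual:
        return [x + a for x, a in zip(node_feats, aggregated)]
    return aggregated


nodes = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1), (1, 0)]  # (j, i): message flows from j to i
weights = [0.5, 1.0, 2.0]
updated = message_passing_step(nodes, edges, weights)
```

Node 2 receives no messages here, so with the residual enabled its feature is unchanged.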

rl4co.models.nn.attention

class rl4co.models.nn.attention.LogitAttention(embed_dim, num_heads, tanh_clipping=10.0, mask_inner=True, mask_logits=True, normalize=True, softmax_temp=1.0, linear_bias=False, sdp_fn=<built-in function scaled_dot_product_attention>)[source]

Bases: Module

Calculate logits given the query, key, value, and logit key.

Note

With Flash Attention, masking is not supported.

Perform the following:
  1. Apply cross attention to get the heads

  2. Project heads to get glimpse

  3. Compute attention score between glimpse and logit key

  4. Normalize and mask

Parameters:
  • embed_dim (int) – total dimension of the model

  • num_heads (int) – number of heads

  • tanh_clipping (float) – tanh clipping value

  • mask_inner (bool) – whether to mask inner attention

  • mask_logits (bool) – whether to mask logits

  • normalize (bool) – whether to normalize logits

  • softmax_temp (float) – softmax temperature

  • linear_bias (bool) – whether to use bias in linear projection

  • sdp_fn – scaled dot product attention function (SDPA)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(query, key, value, logit_key, mask, softmax_temp=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
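Step 4 of the LogitAttention description (normalize and mask), together with the tanh_clipping and softmax_temp parameters, can be sketched in plain Python for a single row of scores. This is a hedged illustration, not the library code: the function name is hypothetical, the mask is assumed to be True for feasible entries, and at least one entry is assumed to be feasible.

```python
import math


def clip_mask_normalize(scores, feasible, tanh_clipping=10.0, softmax_temp=1.0):
    # Tanh clipping bounds every logit to (-tanh_clipping, tanh_clipping)
    logits = [tanh_clipping * math.tanh(s) for s in scores]
    # Masking (assumed convention): infeasible entries become -inf
    logits = [l if ok else float("-inf") for l, ok in zip(logits, feasible)]
    # Temperature-scaled log-softmax, stabilized by subtracting the max
    scaled = [l / softmax_temp for l in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(l - m) for l in scaled))
    return [l - log_z for l in scaled]


log_probs = clip_mask_normalize([0.3, 5.0, -1.2], feasible=[True, False, True])
probs = [math.exp(lp) for lp in log_probs]
```

The masked entry ends up with probability zero, and the remaining probabilities sum to one.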

class rl4co.models.nn.attention.MultiHeadAttention(embed_dim, num_heads, bias=True, attention_dropout=0.0, causal=False, device=None, dtype=None, sdpa_fn=None)[source]

Bases: Module

PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support. Uses PyTorch’s native scaled_dot_product_attention implementation, available from PyTorch 2.0.

Note

If scaled_dot_product_attention is not available, a custom implementation of scaled_dot_product_attention without Flash Attention is used instead.

Parameters:
  • embed_dim (int) – total dimension of the model

  • num_heads (int) – number of heads

  • bias (bool) – whether to use bias

  • attention_dropout (float) – dropout rate for attention weights

  • causal (bool) – whether to apply causal mask to attention scores

  • device (str) – torch device

  • dtype (dtype) – torch dtype

  • sdpa_fn (Optional[Callable]) – scaled dot product attention function (SDPA)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, key_padding_mask=None)[source]

Parameters:
  • x – (batch, seqlen, hidden_dim), where hidden_dim = num heads * head dim

  • key_padding_mask – bool tensor of shape (batch, seqlen)

rl4co.models.nn.attention.scaled_dot_product_attention_simple(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)[source]

Simple Scaled Dot-Product Attention in PyTorch without Flash Attention
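For reference, scaled dot-product attention computes softmax(q kᵀ / sqrt(d)) v. The sketch below works on nested Python lists for a single head so that it runs without PyTorch. The function name is hypothetical, dropout and the causal option are omitted, and a boolean mask is assumed to mean "True may attend".

```python
import math


def sdpa_sketch(q, k, v, attn_mask=None):
    """q: [L, d], k: [S, d], v: [S, d_v]; attn_mask: optional [L, S] booleans."""
    d = len(q[0])
    out = []
    for i, q_i in enumerate(q):
        # Scaled similarity between query i and every key
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d) for k_j in k]
        if attn_mask is not None:
            scores = [s if keep else float("-inf")
                      for s, keep in zip(scores, attn_mask[i])]
        # Softmax over keys, stabilized by subtracting the max
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of value vectors
        out.append([sum(w * v_j[c] for w, v_j in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0], [20.0]]
unmasked = sdpa_sketch(q, k, v)
masked = sdpa_sketch(q, k, v, attn_mask=[[True, False]])
```

With the second key masked out, the output is exactly the first value vector. Without the mask, the output is a blend weighted toward the better-matching key.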

rl4co.models.nn.flash_attention

rl4co.models.nn.mlp

class rl4co.models.nn.mlp.MLP(input_dim, output_dim, num_neurons=[64, 32], hidden_act='ReLU', out_act='Identity', input_norm='None', output_norm='None')[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(xs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
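The MLP signature suggests that num_neurons lists the hidden layer widths in order. Under that assumption (the source does not document it), the linear layer shapes would chain from input_dim through each hidden width to output_dim. The helper below is hypothetical and only computes those shapes:

```python
def mlp_layer_shapes(input_dim, output_dim, num_neurons=(64, 32)):
    """Return (in_features, out_features) for each linear layer (assumed layout)."""
    in_dims = [input_dim] + list(num_neurons)
    out_dims = list(num_neurons) + [output_dim]
    return list(zip(in_dims, out_dims))


shapes = mlp_layer_shapes(input_dim=128, output_dim=1)
```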

rl4co.models.nn.ops

class rl4co.models.nn.ops.Normalization(embed_dim, normalization='batch')[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

init_parameters()[source]

class rl4co.models.nn.ops.SkipConnection(module)[source]

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

rl4co.models.nn.utils

class rl4co.models.nn.utils.RandomPolicy(env_name=None)[source]

Bases: Module

Random policy class that randomly selects actions from the action space. This policy can be useful for checking the sanity of the environment during development.

The signature of forward matches that of the AutoregressivePolicy class.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(td, env=None, max_steps=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

rl4co.models.nn.utils.decode_probs(probs, mask, decode_type='sampling')[source]

Decode probabilities to select actions with mask
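Decoding with a mask can be sketched as follows for a single probability vector. This is an illustration under stated assumptions rather than the library's behavior: the function name is hypothetical, the mask is assumed to be True for feasible actions, and a "greedy" decode type is assumed alongside the documented default 'sampling'.

```python
import random


def decode(probs, feasible, decode_type="sampling", rng=random):
    # Zero out infeasible actions so they can never be selected
    masked = [p if ok else 0.0 for p, ok in zip(probs, feasible)]
    if decode_type == "greedy":
        # Pick the feasible action with the highest probability
        return max(range(len(masked)), key=masked.__getitem__)
    if decode_type == "sampling":
        # Sample proportionally to the remaining probability mass
        return rng.choices(range(len(masked)), weights=masked, k=1)[0]
    raise ValueError(f"unknown decode_type: {decode_type}")


probs = [0.5, 0.3, 0.2]
feasible = [False, True, True]
greedy_action = decode(probs, feasible, decode_type="greedy")
sampled_action = decode(probs, feasible, rng=random.Random(0))
```

Although action 0 has the highest raw probability, it is infeasible, so greedy decoding returns action 1 and sampling returns either action 1 or action 2.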

rl4co.models.nn.utils.get_log_likelihood(log_p, actions, mask, return_sum=True)[source]

Get log likelihood of selected actions
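The log likelihood of a trajectory is the sum of the log probabilities of the actions actually taken. A minimal single-instance sketch is shown below. The function name is hypothetical, and the mask is assumed to flag steps whose contribution should be zeroed out.

```python
import math


def log_likelihood(log_p, actions, mask=None, return_sum=True):
    """log_p: [steps][num_actions], actions: [steps], mask: [steps] booleans."""
    # Gather the log probability of the chosen action at each step
    selected = [step_log_p[a] for step_log_p, a in zip(log_p, actions)]
    if mask is not None:
        # Assumed convention: True marks a step to ignore
        selected = [0.0 if skip else lp for lp, skip in zip(selected, mask)]
    return sum(selected) if return_sum else selected


log_p = [
    [math.log(0.5), math.log(0.5)],
    [math.log(0.25), math.log(0.75)],
]
ll = log_likelihood(log_p, actions=[0, 1])
ll_masked = log_likelihood(log_p, actions=[0, 1], mask=[False, True])
```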

rl4co.models.nn.utils.random_policy(td)[source]

Helper function to select a random action from available actions

rl4co.models.nn.utils.rollout(env, td, policy, max_steps=None)[source]

Helper function to roll out a policy. Currently, TorchRL's env.rollout() does not allow stepping over environments that are already done. This helper is needed because environments in a batch can complete at different steps.
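The idea of rolling out a batch whose instances finish at different steps can be illustrated with a toy loop. Everything here is hypothetical: the "environment" is just a per-instance countdown, and skipping finished instances is a design choice of this sketch, not a description of how the library handles them.

```python
def rollout_sketch(remaining, policy, max_steps=None):
    """Step a batch of toy episodes until every one has finished.

    remaining[b] is the number of steps instance b still needs; an instance
    is done when it reaches zero. Returns one action list per instance.
    """
    remaining = list(remaining)
    actions = [[] for _ in remaining]
    steps = 0
    while any(r > 0 for r in remaining):
        if max_steps is not None and steps >= max_steps:
            break
        for b, r in enumerate(remaining):
            if r > 0:  # finished instances are no longer stepped
                actions[b].append(policy(b, r))
                remaining[b] = r - 1
        steps += 1
    return actions


trajectories = rollout_sketch([1, 3, 2], policy=lambda b, r: r)
```

The loop runs until the slowest instance finishes, so each trajectory has a different length.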