nn Modules
Graph Neural Networks
Graph Attention Encoder
- class rl4co.models.nn.graph.attnnet.GraphAttentionNetwork(num_heads, embedding_dim, num_layers, normalization='batch', feed_forward_hidden=512, sdpa_fn=None)
Bases: Module
Graph Attention Network that encodes embeddings with a series of MHA layers, each consisting of an MHA layer, normalization, a feed-forward layer, and a second normalization. Similar to the Transformer encoder, as used in Kool et al. (2019).
- Parameters:
num_heads (int) – number of heads in the MHA
embedding_dim (int) – dimension of the embeddings
num_layers (int) – number of MHA layers
normalization (str) – type of normalization to use (batch, layer, none)
feed_forward_hidden (int) – dimension of the hidden layer in the feed-forward layer
sdpa_fn (Optional[Callable]) – scaled dot product attention function (SDPA)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
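The layer ordering described above (MHA, normalization, feed-forward, normalization) can be sketched in plain PyTorch. This is an illustrative sketch, not rl4co's implementation: the class name SketchEncoderLayer is hypothetical, torch's built-in nn.MultiheadAttention stands in for the library's own MHA module, and the skip connections around the attention and feed-forward blocks are an assumption based on the Transformer encoder the description refers to.

```python
import torch
import torch.nn as nn


class SketchEncoderLayer(nn.Module):
    """Hypothetical stand-in for one encoder layer: MHA, norm, feed-forward, norm."""

    def __init__(self, embed_dim=128, num_heads=8, feed_forward_hidden=512):
        super().__init__()
        # Assumption: torch's built-in attention replaces rl4co's own MHA module
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.BatchNorm1d(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, feed_forward_hidden),
            nn.ReLU(),
            nn.Linear(feed_forward_hidden, embed_dim),
        )
        self.norm2 = nn.BatchNorm1d(embed_dim)

    @staticmethod
    def _batch_norm(norm, x):
        # BatchNorm1d expects (samples, features), so merge the batch and node axes
        b, n, d = x.shape
        return norm(x.reshape(b * n, d)).reshape(b, n, d)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self._batch_norm(self.norm1, x + attn_out)  # assumed skip connection
        x = self._batch_norm(self.norm2, x + self.ff(x))  # assumed skip connection
        return x


num_layers = 3
encoder = nn.Sequential(*[SketchEncoderLayer() for _ in range(num_layers)])
h = encoder(torch.randn(4, 20, 128))  # (batch, nodes, embedding_dim)
```

The output keeps the input shape, so layers can be stacked num_layers times without any reshaping in between.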
- class rl4co.models.nn.graph.attnnet.MultiHeadAttentionLayer(num_heads, embed_dim, feed_forward_hidden=512, normalization='batch', sdpa_fn=None)
Bases: Sequential
Multi-Head Attention Layer with normalization and feed-forward layer
- Parameters:
num_heads (int) – number of heads in the MHA
embed_dim (int) – dimension of the embeddings
feed_forward_hidden (int) – dimension of the hidden layer in the feed-forward layer
normalization (Optional[str]) – type of normalization to use (batch, layer, none)
sdpa_fn (Optional[Callable]) – scaled dot product attention function (SDPA)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Graph Convolutional Encoder
- class rl4co.models.nn.graph.gcn.GCNEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, self_loop=False, residual=True)
Bases: Module
Graph Convolutional Network to encode embeddings with a series of GCN layers
- Parameters:
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Message Passing Encoder
- class rl4co.models.nn.graph.mpnn.MessagePassingEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, aggregation='add', self_loop=False, residual=True)
Bases: Module
Note
Only fully connected graphs are supported for now.
- forward(td, mask=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Return type: Tuple[Tensor, Tensor]
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
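The difference between calling the Module instance and calling forward directly can be seen with a forward hook on any PyTorch module. The snippet below uses a plain nn.Linear as a stand-in; the calls list is only there for illustration.

```python
import torch
import torch.nn as nn

calls = []
layer = nn.Linear(4, 2)
# Record every time the forward hook fires
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
layer(x)          # goes through __call__, so the hook runs
layer.forward(x)  # bypasses __call__, so the hook is skipped
```

After both calls, calls holds a single entry, which is why the instance call is the recommended one.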
- class rl4co.models.nn.graph.mpnn.MessagePassingLayer(node_indim, node_outdim, edge_indim, edge_outdim, aggregation='add', residual=False, **mlp_params)
Bases: MessagePassing
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- edge_update(nf, ef, edge_index)
Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- forward(node_feature, edge_feature, edge_index, mask=None)
Runs the forward pass of the module.
- message(edge_features)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
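The edge update, message, and 'add' aggregation steps can be sketched in plain PyTorch, without the MessagePassing base class. This is a simplified illustration under stated assumptions: the function name sketch_message_passing is hypothetical, the learned MLPs are replaced by plain sums, and edge_index is assumed to hold source nodes \(j\) in its first row and target nodes \(i\) in its second.

```python
import torch


def sketch_message_passing(node_feat, edge_feat, edge_index):
    """node_feat: (N, D); edge_feat: (E, D); edge_index: (2, E) as (source j, target i)."""
    src, dst = edge_index
    # Edge update: combine each edge feature with its two endpoint features
    # (a sum stands in for the learned edge MLP)
    new_edge = edge_feat + node_feat[src] + node_feat[dst]
    # Message and 'add' aggregation: sum incoming edge messages per target node
    aggregated = torch.zeros_like(node_feat).index_add_(0, dst, new_edge)
    # Residual node update
    return node_feat + aggregated, new_edge


# Fully connected graph on 3 nodes, without self-loops
edge_index = torch.tensor([(j, i) for j in range(3) for i in range(3) if i != j]).T
node_feat = torch.ones(3, 2)
edge_feat = torch.zeros(6, 2)
new_node, new_edge = sketch_message_passing(node_feat, edge_feat, edge_index)
```

Each node here receives two messages of value 2, so the residual update yields 5 in every entry.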
rl4co.models.nn.attention
- class rl4co.models.nn.attention.LogitAttention(embed_dim, num_heads, tanh_clipping=10.0, mask_inner=True, mask_logits=True, normalize=True, softmax_temp=1.0, linear_bias=False, sdp_fn=<built-in function scaled_dot_product_attention>)
Bases: Module
Calculate logits given query, key and value and logit key.
Note
With Flash Attention, masking is not supported.
- Perform the following:
Apply cross attention to get the heads
Project heads to get glimpse
Compute attention score between glimpse and logit key
Normalize and mask
- Parameters:
embed_dim (int) – total dimension of the model
num_heads (int) – number of heads
tanh_clipping (float) – tanh clipping value
mask_inner (bool) – whether to mask inner attention
mask_logits (bool) – whether to mask logits
normalize (bool) – whether to normalize logits
softmax_temp (float) – softmax temperature
linear_bias (bool) – whether to use bias in linear projection
sdp_fn – scaled dot product attention function (SDPA)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
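The four steps listed for LogitAttention can be sketched as a standalone function. This is a minimal illustration and not the library's code: sketch_logits is a hypothetical name, the glimpse projection is left as an identity, the attention is computed by hand rather than through sdp_fn, and the mask convention (True marks a feasible node) is an assumption of this sketch.

```python
import math
import torch


def sketch_logits(query, key, value, logit_key, mask, num_heads=4, tanh_clipping=10.0):
    """query: (B, 1, D); key, value, logit_key: (B, N, D); mask: (B, N), True = feasible."""
    B, N, D = key.shape
    head_dim = D // num_heads

    def split(t):  # (B, L, D) -> (B, H, L, head_dim)
        return t.reshape(B, -1, num_heads, head_dim).transpose(1, 2)

    q, k, v = split(query), split(key), split(value)
    # 1. Cross attention to get the heads (inner mask applied to the scores)
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    scores = scores.masked_fill(~mask[:, None, None, :], float("-inf"))
    heads = torch.softmax(scores, dim=-1) @ v  # (B, H, 1, head_dim)
    # 2. Project heads to get the glimpse (learned projection omitted in this sketch)
    glimpse = heads.transpose(1, 2).reshape(B, 1, D)
    # 3. Attention score between the glimpse and the logit key
    logits = torch.bmm(glimpse, logit_key.transpose(1, 2)).squeeze(1) / math.sqrt(D)
    # 4. Clip with tanh, mask infeasible nodes, normalize to log-probabilities
    logits = torch.tanh(logits) * tanh_clipping
    logits = logits.masked_fill(~mask, float("-inf"))
    return torch.log_softmax(logits, dim=-1)


torch.manual_seed(0)
B, N, D = 2, 5, 16
mask = torch.tensor([[True, True, False, True, True],
                     [True, False, False, True, True]])
log_p = sketch_logits(torch.randn(B, 1, D), torch.randn(B, N, D),
                      torch.randn(B, N, D), torch.randn(B, N, D), mask)
```

Masked nodes end up with a log-probability of negative infinity, so they can never be selected during decoding.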
- forward(query, key, value, logit_key, mask, softmax_temp=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class rl4co.models.nn.attention.MultiHeadAttention(embed_dim, num_heads, bias=True, attention_dropout=0.0, causal=False, device=None, dtype=None, sdpa_fn=None)
Bases: Module
PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support. Uses PyTorch's native scaled_dot_product_attention implementation, available from PyTorch 2.0.
Note
If scaled_dot_product_attention is not available, a custom implementation of scaled_dot_product_attention without Flash Attention is used.
- Parameters:
embed_dim (int) – total dimension of the model
num_heads (int) – number of heads
bias (bool) – whether to use bias
attention_dropout (float) – dropout rate for attention weights
causal (bool) – whether to apply causal mask to attention scores
device (str) – torch device
dtype (dtype) – torch dtype
sdpa_fn (Optional[Callable]) – scaled dot product attention function (SDPA)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
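A fallback scaled dot product attention without Flash Attention could look roughly like the sketch below. The function sketch_sdpa is hypothetical; it mirrors the keyword arguments of torch.nn.functional.scaled_dot_product_attention, and it is an assumption that a callable of this shape is what sdpa_fn accepts.

```python
import math
import torch
import torch.nn.functional as F


def sketch_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    """Plain-PyTorch attention; q, k, v have shape (..., length, head_dim)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if is_causal:
        # Position t may only attend to positions <= t
        causal = torch.ones(q.size(-2), k.size(-2), dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            scores = scores.masked_fill(~attn_mask, float("-inf"))  # True = may attend
        else:
            scores = scores + attn_mask  # additive float mask
    weights = torch.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        weights = F.dropout(weights, p=dropout_p)
    return weights @ v


q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))  # (batch, heads, length, head_dim)
out = sketch_sdpa(q, k, v, is_causal=True)
```

With the causal mask, the first position attends only to itself, so its output equals the first value vector.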
rl4co.models.nn.flash_attention
rl4co.models.nn.mlp
- class rl4co.models.nn.mlp.MLP(input_dim, output_dim, num_neurons=[64, 32], hidden_act='ReLU', out_act='Identity', input_norm='None', output_norm='None')
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(xs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
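The MLP signature suggests a stack of linear layers whose widths come from num_neurons, with hidden_act between layers and out_act at the end. The sketch below builds such a stack under that reading; sketch_mlp is a hypothetical helper, the activation names are assumed to be looked up in torch.nn, and the input_norm and output_norm options are left out.

```python
import torch
import torch.nn as nn


def sketch_mlp(input_dim, output_dim, num_neurons=(64, 32), hidden_act="ReLU", out_act="Identity"):
    """Build Linear layers input_dim -> num_neurons... -> output_dim with activations."""
    dims = [input_dim, *num_neurons, output_dim]
    layers = []
    for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        is_last = i == len(dims) - 2
        # Assumption: activation names refer to classes in torch.nn
        layers.append(getattr(nn, out_act if is_last else hidden_act)())
    return nn.Sequential(*layers)


mlp = sketch_mlp(input_dim=10, output_dim=3)
ys = mlp(torch.randn(5, 10))
```

With the default num_neurons, this gives three linear layers of widths 10 to 64, 64 to 32, and 32 to 3.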
rl4co.models.nn.ops
- class rl4co.models.nn.ops.Normalization(embed_dim, normalization='batch')
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class rl4co.models.nn.ops.SkipConnection(module)
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
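A skip connection wrapper conventionally adds the input back onto the wrapped module's output. The sketch below shows that pattern; SketchSkip is a hypothetical name, and it is an assumption that SkipConnection follows this same x + module(x) form.

```python
import torch
import torch.nn as nn


class SketchSkip(nn.Module):
    """Hypothetical residual wrapper: returns x + module(x)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # The wrapped module must preserve the shape of x for the sum to be valid
        return x + self.module(x)


x = torch.randn(2, 5, 8)
doubled = SketchSkip(nn.Identity())(x)  # identity inside, so the result is 2 * x
residual = SketchSkip(nn.Linear(8, 8))(x)
```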
rl4co.models.nn.utils
- class rl4co.models.nn.utils.RandomPolicy(env_name=None)
Bases: Module
Random policy class that randomly selects actions from the action space. This policy can be useful for checking the sanity of the environment during development.
The function signature of forward matches that of the AutoregressivePolicy class.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(td, env=None, max_steps=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- rl4co.models.nn.utils.decode_probs(probs, mask, decode_type='sampling')
Decode probabilities to select actions with mask
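Masked decoding can be sketched as follows. This is an illustration only: sketch_decode is a hypothetical function, the "greedy" option and the mask convention (True marks a feasible action) are assumptions, and the library's own behavior may differ.

```python
import torch


def sketch_decode(probs, mask, decode_type="sampling"):
    """probs: (B, N) action probabilities; mask: (B, N), True = feasible (assumed)."""
    probs = probs * mask  # zero out infeasible actions
    if decode_type == "greedy":
        return probs.argmax(dim=-1)
    if decode_type == "sampling":
        # multinomial accepts unnormalized non-negative weights
        return torch.multinomial(probs, 1).squeeze(-1)
    raise ValueError(f"Unknown decode type: {decode_type}")


probs = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.4, 0.1]])
mask = torch.tensor([[True, False, True], [True, True, True]])
greedy = sketch_decode(probs, mask, "greedy")
sampled = sketch_decode(probs, mask, "sampling")
```

In the first row the most probable action is masked out, so greedy decoding falls back to the best feasible one.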
- rl4co.models.nn.utils.get_log_likelihood(log_p, actions, mask, return_sum=True)
Get log likelihood of selected actions
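The log likelihood of a sequence of selected actions is typically obtained by gathering the log-probability of each chosen action and summing over decoding steps. The sketch below assumes log_p has shape (batch, steps, nodes) and that mask marks steps to ignore with True; sketch_log_likelihood is a hypothetical name and both conventions are assumptions.

```python
import torch


def sketch_log_likelihood(log_p, actions, mask=None, return_sum=True):
    """log_p: (B, T, N); actions: (B, T); mask: (B, T), True = ignore step (assumed)."""
    # Pick the log-probability of the action taken at each step
    ll = log_p.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # (B, T)
    if mask is not None:
        ll = ll.masked_fill(mask, 0.0)  # ignored steps contribute nothing
    return ll.sum(dim=1) if return_sum else ll


log_p = torch.tensor([[[0.5, 0.5], [0.25, 0.75]]]).log()  # one instance, two steps
actions = torch.tensor([[0, 1]])
total = sketch_log_likelihood(log_p, actions)  # log(0.5) + log(0.75)
first_only = sketch_log_likelihood(log_p, actions, mask=torch.tensor([[False, True]]))
```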