nn Modules¶
Graph Neural Networks¶
Graph Attention Encoder¶
- class rl4co.models.nn.graph.attnnet.GraphAttentionNetwork(num_heads, embedding_dim, num_layers, normalization='batch', feed_forward_hidden=512, force_flash_attn=False)[source]¶
Bases: Module
Graph Attention Network that encodes embeddings with a stack of layers, each consisting of an MHA layer, normalization, a feed-forward layer, and normalization. Similar to the Transformer encoder, as used in Kool et al. (2019).
- Parameters:
num_heads¶ (int) – number of heads in the MHA
embedding_dim¶ (int) – dimension of the embeddings
num_layers¶ (int) – number of MHA layers
normalization¶ (str) – type of normalization to use (batch, layer, none)
feed_forward_hidden¶ (int) – dimension of the hidden layer in the feed-forward layer
force_flash_attn¶ (bool) – whether to force FlashAttention (move to half precision)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- class rl4co.models.nn.graph.attnnet.MultiHeadAttentionLayer(num_heads, embed_dim, feed_forward_hidden=512, normalization='batch', force_flash_attn=False)[source]¶
Bases: Sequential
Multi-Head Attention Layer with normalization and feed-forward layer
- Parameters:
num_heads¶ (int) – number of heads in the MHA
embed_dim¶ (int) – dimension of the embeddings
feed_forward_hidden¶ (int) – dimension of the hidden layer in the feed-forward layer
normalization¶ (Optional[str]) – type of normalization to use (batch, layer, none)
force_flash_attn¶ (bool) – whether to force FlashAttention (move to half precision)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Graph Convolutional Encoder¶
- class rl4co.models.nn.graph.gcn.GCNEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, self_loop=False, residual=True)[source]¶
Bases: Module
Graph Convolutional Network to encode embeddings with a series of GCN layers
- Parameters:
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Message Passing Encoder¶
- class rl4co.models.nn.graph.mpnn.MessagePassingEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, aggregation='add', self_loop=False, residual=True)[source]¶
Bases: Module
Note
Only fully connected graphs are supported for now.
- forward(td, mask=None)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tuple[Tensor, Tensor]
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class rl4co.models.nn.graph.mpnn.MessagePassingLayer(node_indim, node_outdim, edge_indim, edge_outdim, aggregation='add', residual=False, **mlp_params)[source]¶
Bases: MessagePassing
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- edge_update(nf, ef, edge_index)[source]¶
Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- forward(node_feature, edge_feature, edge_index, mask=None)[source]¶
Runs the forward pass of the module.
- message(edge_features)[source]¶
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
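The message-and-aggregate pattern described above can be illustrated without PyTorch Geometric. The following is a minimal, dependency-free sketch, assuming scalar edge features and a hypothetical `propagate_add` helper that is not part of rl4co. Each edge carries a message from source node j to target node i, and messages are summed per target, which is what `aggregation='add'` refers to.

```python
def propagate_add(num_nodes, edge_index, edge_features):
    """Sum the message carried by each edge (j -> i) into its target node i.

    edge_index is a pair of lists (sources, targets), following the
    PyTorch Geometric convention; edge_features holds one scalar per edge.
    This helper is hypothetical and only illustrates 'add' aggregation.
    """
    sources, targets = edge_index
    aggregated = [0.0] * num_nodes
    for _, target, message in zip(sources, targets, edge_features):
        aggregated[target] += message
    return aggregated


# Fully connected graph on 3 nodes without self-loops.
edge_index = (
    [0, 0, 1, 1, 2, 2],  # source nodes j
    [1, 2, 0, 2, 0, 1],  # target nodes i
)
edge_features = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
node_messages = propagate_add(3, edge_index, edge_features)
```

In the real layer the messages are tensors produced by an MLP rather than fixed scalars; only the per-target reduction is shown here.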
rl4co.models.nn.attention¶
- class rl4co.models.nn.attention.LogitAttention(embed_dim, num_heads, tanh_clipping=10.0, mask_inner=True, mask_logits=True, normalize=True, softmax_temp=1.0, linear_bias=False, force_flash_attn=False)[source]¶
Bases: Module
Calculate logits given query, key and value and logit key. If we use Flash Attention, then we automatically move to fp16 for inner computations. Note: with Flash Attention, masking is not supported.
- Perform the following:
1. Apply cross attention to get the heads
2. Project heads to get glimpse
3. Compute attention score between glimpse and logit key
4. Normalize and mask
- Parameters:
embed_dim¶ (int) – total dimension of the model
num_heads¶ (int) – number of heads
tanh_clipping¶ (float) – tanh clipping value
mask_inner¶ (bool) – whether to mask inner attention
mask_logits¶ (bool) – whether to mask logits
normalize¶ (bool) – whether to normalize logits
softmax_temp¶ (float) – softmax temperature
force_flash_attn¶ (bool) – whether to force flash attention. If True, then we automatically cast to fp16
Initializes internal Module state, shared by both nn.Module and ScriptModule.
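The final "normalize and mask" step, together with the `tanh_clipping` and `softmax_temp` parameters, can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: `clip_mask_normalize` is a hypothetical helper, the `feasible` flag convention (True means selectable) is assumed, and the order of clipping, masking, and temperature scaling follows the common pointer-style decoder recipe.

```python
import math


def clip_mask_normalize(scores, feasible, tanh_clipping=10.0, softmax_temp=1.0):
    """Clip raw scores with tanh, mask infeasible entries, return log-probabilities.

    feasible[i] is True when action i may be selected (an assumed convention).
    """
    clipped = [tanh_clipping * math.tanh(s) for s in scores]
    masked = [c if ok else -math.inf for c, ok in zip(clipped, feasible)]
    scaled = [m / softmax_temp for m in masked]
    # Numerically stable log-softmax; masked entries contribute exp(-inf) = 0.
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_norm for s in scaled]


log_p = clip_mask_normalize([0.5, 2.0, -1.0], [True, False, True])
```

Clipping bounds every logit to the interval (-tanh_clipping, tanh_clipping), and a masked action ends up with log-probability negative infinity. The sketch assumes at least one feasible action; with none, the result is undefined.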
- flash_attn_wrapper(func, *args, **kwargs)¶
Wrapper for flash attention to automatically cast to fp16 if needed
- forward(query, key, value, logit_key, mask, softmax_temp=None)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class rl4co.models.nn.attention.MultiHeadAttention(embed_dim, num_heads, bias=True, attention_dropout=0.0, causal=False, device=None, dtype=None, force_flash_attn=False)[source]¶
Bases: Module
PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support. Uses PyTorch’s native scaled_dot_product_attention implementation, available from PyTorch 2.0.
Note
If scaled_dot_product_attention is not available, a custom implementation of scaled_dot_product_attention without Flash Attention is used instead. If you want to use Flash Attention, have a look at the MHA module under rl4co.models.nn.flash_attention.MHA.
- Parameters:
embed_dim¶ (int) – total dimension of the model
num_heads¶ (int) – number of heads
bias¶ (bool) – whether to use bias
attention_dropout¶ (float) – dropout rate for attention weights
causal¶ (bool) – whether to apply causal mask to attention scores
device¶ – torch device
dtype¶ – torch dtype
force_flash_attn¶ (bool) – whether to force flash attention. If True, then we automatically cast to fp16
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- flash_attn_wrapper(func, *args, **kwargs)¶
Wrapper for flash attention to automatically cast to fp16 if needed
rl4co.models.nn.flash_attention¶
- class rl4co.models.nn.flash_attention.CrossAttention(causal=False, softmax_scale=None, attention_dropout=0.0, neg_inf=-10000.0)[source]¶
Bases: Module
Implement the scaled dot product attention with softmax.
- Parameters:
softmax_scale¶ – (default: 1/sqrt(d_keys), where d_keys is computed at runtime)
attention_dropout¶ – the dropout rate to apply to the attention (default: 0.0)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(q, kv, causal=None, key_padding_mask=None)[source]¶
Implements the multihead softmax attention.
- Parameters:
q¶ – the tensor containing the query. (B, Sq, H, D)
kv¶ – the tensor containing the key and value. (B, Sk, 2, H, D)
causal¶ – if passed, will override self.causal
key_padding_mask¶ – boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, Sk)
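To make the default scale and the `key_padding_mask` convention concrete, here is a minimal single-query, single-head sketch in plain Python. `masked_attention` is a hypothetical helper, not the library's code. The class signature above uses a large negative constant (`neg_inf=-10000.0`) for masked scores, whereas the sketch uses negative infinity for simplicity.

```python
import math


def masked_attention(query, keys, values, key_padding_mask, softmax_scale=None):
    """Single-query, single-head attention over a list of key/value vectors.

    key_padding_mask[k] is True to keep key k and False to mask it out.
    softmax_scale defaults to 1/sqrt(d_keys), with d_keys read at runtime.
    """
    d_keys = len(query)
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d_keys)
    scores = [
        scale * sum(q * k for q, k in zip(query, key)) if keep else -math.inf
        for key, keep in zip(keys, key_padding_mask)
    ]
    # Stable softmax; masked keys receive exactly zero weight.
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum of the value vectors.
    dim = len(values[0])
    output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return output, weights


out, attn = masked_attention(
    query=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    values=[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]],
    key_padding_mask=[True, True, False],
)
```

The third key is masked out, so its value vector does not influence `out` no matter how large it is.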
- class rl4co.models.nn.flash_attention.FlashCrossAttention(causal=False, softmax_scale=None, attention_dropout=0.0, triton=False)[source]¶
Bases: Module
Implement the scaled dot product attention with softmax.
- Parameters:
softmax_scale¶ – (default: 1/sqrt(d_keys), where d_keys is computed at runtime)
attention_dropout¶ – the dropout rate to apply to the attention (default: 0.0)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(q, kv, causal=None, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None)[source]¶
Implements the multihead softmax attention.
- Parameters:
q¶ – the tensor containing the query. (B, Sq, H, D)
kv¶ – the tensor containing the key and value. (B, Sk, 2, H, D)
causal¶ – if passed, will override self.causal
cu_seqlens¶ – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q.
max_seqlen¶ – int. Maximum sequence length in the batch of q.
cu_seqlens_k¶ – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv.
max_seqlen_k¶ – int. Maximum sequence length in the batch of k and v.
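The cumulative sequence lengths are a prefix sum over the per-sequence lengths, with a leading zero, which is why the shape is (batch_size + 1,). A small sketch with plain Python lists follows. `build_cu_seqlens` is a hypothetical helper; in practice the result would be converted to a torch.int32 tensor.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lengths):
    """Return (cu_seqlens, max_seqlen) for a batch of variable-length sequences."""
    cu_seqlens = [0] + list(accumulate(seq_lengths))
    return cu_seqlens, max(seq_lengths)


cu_seqlens, max_seqlen = build_cu_seqlens([3, 5, 2])
# Sequence b occupies rows cu_seqlens[b]:cu_seqlens[b + 1] of the packed tensor.
spans = [(cu_seqlens[b], cu_seqlens[b + 1]) for b in range(len(cu_seqlens) - 1)]
```

The last entry of `cu_seqlens` equals the total number of tokens in the packed batch.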
- class rl4co.models.nn.flash_attention.FlashSelfAttention(causal=False, softmax_scale=None, attention_dropout=0.0, triton=False)[source]¶
Bases: Module
Implement the scaled dot product attention with softmax.
- Parameters:
softmax_scale¶ – (default: 1/sqrt(d_keys), where d_keys is computed at runtime)
attention_dropout¶ – the dropout rate to apply to the attention (default: 0.0)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(qkv, causal=None, cu_seqlens=None, max_seqlen=None)[source]¶
Implements the multihead softmax attention.
- Parameters:
qkv¶ – If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). If cu_seqlens is not None and max_seqlen is not None, then qkv has shape (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
- Returns:
out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, else (B, S, H, D).
- class rl4co.models.nn.flash_attention.LinearResidual(in_features, out_features, bias=True, device=None, dtype=None)[source]¶
Bases: Linear
Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(input)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class rl4co.models.nn.flash_attention.MHA(embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0, rotary_emb_scale_base=0, fused_bias_fc=False, use_flash_attn=False, force_dtype_flash_attn=True, return_residual=False, device=None, dtype=None)[source]¶
Bases: Module
Multi-head self-attention and cross-attention
- return_residual: whether to return the input x along with the output. This is for performance reasons: for post-norm architectures, returning the input allows us to fuse the backward of nn.Linear with the residual connection.
- flash_attn_wrapper(func, *args, **kwargs)¶
Wrapper for flash attention to automatically cast to fp16 if needed
- forward(x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, mixer_subset=None, inference_params=None, **kwargs)[source]¶
- Parameters:
x¶ – (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim), where total is the sum of the sequence lengths in the batch.
x_kv¶ – (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
cu_seqlens¶ – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into x. Only applicable when using FlashAttention.
max_seqlen¶ – int. Maximum sequence length in the batch.
key_padding_mask¶ – boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention.
mixer_subset¶ – for cross-attention only. If not None, will take a subset of x before applying the query projection. Useful for e.g., ViT where we only care about the CLS token in the last layer.
inference_params¶ – for generation. Adapted from Megatron-LM (and Apex): https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
- class rl4co.models.nn.flash_attention.SelfAttention(causal=False, softmax_scale=None, attention_dropout=0.0, _inf=-10000.0)[source]¶
Bases: Module
Implement the scaled dot product attention with softmax.
- Parameters:
softmax_scale¶ – (default: 1/sqrt(d_keys), where d_keys is computed at runtime)
attention_dropout¶ – the dropout rate to apply to the attention (default: 0.0)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(qkv, causal=None, key_padding_mask=None)[source]¶
Implements the multihead softmax attention.
- Parameters:
qkv¶ – the tensor containing the query, key, and value. (B, S, 3, H, D)
causal¶ – if passed, will override self.causal
key_padding_mask¶ – boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, S)
rl4co.models.nn.mlp¶
- class rl4co.models.nn.mlp.MLP(input_dim, output_dim, num_neurons=[64, 32], hidden_act='ReLU', out_act='Identity', input_norm='None', output_norm='None')[source]¶
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(xs)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
rl4co.models.nn.ops¶
- class rl4co.models.nn.ops.Normalization(embed_dim, normalization='batch')[source]¶
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class rl4co.models.nn.ops.SkipConnection(module)[source]¶
Bases: Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
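SkipConnection carries no description on this page. Judging by its name and its single `module` argument, it presumably adds the wrapped module's output to its input (a residual connection). The following plain-Python sketch illustrates that assumed behavior with lists instead of tensors; `SkipConnectionSketch` is a hypothetical stand-in, not the library class.

```python
class SkipConnectionSketch:
    """Hypothetical stand-in: adds the wrapped callable's output to its input."""

    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        # Residual connection: x + module(x), applied element-wise.
        return [xi + yi for xi, yi in zip(x, self.module(x))]


def double(x):
    """Toy sub-module that doubles every element."""
    return [2.0 * xi for xi in x]


residual = SkipConnectionSketch(double)
result = residual([1.0, -2.0])
```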
rl4co.models.nn.utils¶
- class rl4co.models.nn.utils.RandomPolicy(env_name=None)[source]¶
Bases: Module
Random policy class that randomly selects actions from the action space. This policy can be useful for checking the sanity of the environment during development.
The signature of forward matches that of the AutoregressivePolicy class.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(td, env=None, max_steps=None)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- rl4co.models.nn.utils.decode_probs(probs, mask, decode_type='sampling')[source]¶
Decode probabilities to select actions with mask
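As a rough illustration of masked decoding, the sketch below picks one action index from a probability vector, either greedily or by sampling. It is written in plain Python under stated assumptions: `decode_probs_sketch` is hypothetical, `mask[i]` being True for feasible actions is an assumed convention, and "greedy" as the alternative to the documented default 'sampling' is also an assumption.

```python
import random


def decode_probs_sketch(probs, mask, decode_type="sampling", rng=random):
    """Pick one action index from probs, never choosing a masked-out action.

    mask[i] is True when action i is feasible (an assumed convention).
    """
    masked = [p if ok else 0.0 for p, ok in zip(probs, mask)]
    if sum(masked) <= 0.0:
        raise ValueError("no feasible action has positive probability")
    if decode_type == "greedy":
        # Index of the largest remaining probability.
        return max(range(len(masked)), key=masked.__getitem__)
    if decode_type == "sampling":
        # random.choices treats the weights as relative, so no renormalization is needed.
        return rng.choices(range(len(masked)), weights=masked, k=1)[0]
    raise ValueError(f"unknown decode_type: {decode_type}")


probs = [0.1, 0.6, 0.3]
mask = [True, False, True]
greedy_action = decode_probs_sketch(probs, mask, decode_type="greedy")
sampled_action = decode_probs_sketch(probs, mask, decode_type="sampling")
```

Although action 1 has the highest raw probability, it is masked out, so neither decoding mode can return it.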
- rl4co.models.nn.utils.get_log_likelihood(log_p, actions, mask, return_sum=True)[source]¶
Get log likelihood of selected actions
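The computation can be sketched for a single trajectory in plain Python: gather the log-probability of each chosen action, zero out masked steps, and optionally sum. `log_likelihood_sketch` is a hypothetical helper, and treating `mask[t]` as True for steps that should not count is an assumption rather than something stated on this page.

```python
def log_likelihood_sketch(log_p, actions, mask=None, return_sum=True):
    """Gather the log-probability of each chosen action along one trajectory.

    log_p[t][a] is the log-probability of action a at step t; mask[t] is True
    for steps that should not count (an assumed convention, e.g. padding).
    """
    selected = [step[a] for step, a in zip(log_p, actions)]
    if mask is not None:
        selected = [0.0 if skip else lp for lp, skip in zip(selected, mask)]
    return sum(selected) if return_sum else selected


log_p = [[-0.5, -1.0], [-2.0, -0.25], [-0.75, -3.0]]
actions = [0, 1, 0]
total = log_likelihood_sketch(log_p, actions, mask=[False, False, True])
per_step = log_likelihood_sketch(log_p, actions, mask=None, return_sum=False)
```

With `return_sum=False` the per-step values are returned instead of their sum, which is useful when a loss needs step-wise weighting.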