
nn Modules

Graph Neural Networks

Graph Attention Encoder

class rl4co.models.nn.graph.attnnet.GraphAttentionNetwork(num_heads, embedding_dim, num_layers, normalization='batch', feed_forward_hidden=512, force_flash_attn=False)[source]

Bases: Module

Graph Attention Network to encode embeddings with a series of MHA layers, each consisting of an MHA sublayer, normalization, a feed-forward sublayer, and normalization. This is similar to the Transformer encoder, as used in Kool et al. (2019).
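To make the layer structure concrete, here is a framework-free sketch of what each encoder layer computes: attention, skip connection, normalization, feed-forward, skip connection, normalization. It is an illustration under simplifying assumptions, not rl4co's implementation: a single head with no learned projections, a toy ReLU feed-forward, and per-node layer normalization (the class itself defaults to batch normalization). All function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Standardize one embedding vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def self_attention(h: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product self-attention with no learned projections."""
    d = len(h[0])
    out = []
    for q in h:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in h]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w / z * v[j] for w, v in zip(weights, h)) for j in range(d)])
    return out


def feed_forward(x: Vector) -> Vector:
    """Toy stand-in for the learned feed-forward sublayer: element-wise ReLU."""
    return [max(0.0, a) for a in x]


def encoder_layer(h: List[Vector]) -> List[Vector]:
    # MHA sublayer with skip connection, followed by normalization
    attn = self_attention(h)
    h = [layer_norm([a + b for a, b in zip(x, y)]) for x, y in zip(h, attn)]
    # Feed-forward sublayer with skip connection, followed by normalization
    return [layer_norm([a + b for a, b in zip(x, feed_forward(x))]) for x in h]


def encode(h: List[Vector], num_layers: int) -> List[Vector]:
    """Apply num_layers encoder layers in sequence; the shape never changes."""
    for _ in range(num_layers):
        h = encoder_layer(h)
    return h
```

Because every sublayer is wrapped in a skip connection, the embedding shape stays [graph_size, embed_dim] through all num_layers layers.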

Parameters:
  • num_heads (int) – number of heads in the MHA

  • embedding_dim (int) – dimension of the embeddings

  • num_layers (int) – number of MHA layers

  • normalization (str) – type of normalization to use (batch, layer, none)

  • feed_forward_hidden (int) – dimension of the hidden layer in the feed-forward layer

  • force_flash_attn (bool) – whether to force FlashAttention (move to half precision)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, mask=None)[source]

Forward pass of the encoder

Parameters:
  • x (Tensor) – [batch_size, graph_size, embed_dim] initial embeddings to process

  • mask (Optional[Tensor]) – [batch_size, graph_size, graph_size] mask for the input embeddings. Unused for now.

Return type:

Tensor

class rl4co.models.nn.graph.attnnet.MultiHeadAttentionLayer(num_heads, embed_dim, feed_forward_hidden=512, normalization='batch', force_flash_attn=False)[source]

Bases: Sequential

Multi-Head Attention Layer with normalization and feed-forward layer

Parameters:
  • num_heads (int) – number of heads in the MHA

  • embed_dim (int) – dimension of the embeddings

  • feed_forward_hidden (int) – dimension of the hidden layer in the feed-forward layer

  • normalization (Optional[str]) – type of normalization to use (batch, layer, none)

  • force_flash_attn (bool) – whether to force FlashAttention (move to half precision)


Graph Convolutional Encoder

class rl4co.models.nn.graph.gcn.GCNEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, self_loop=False, residual=True)[source]

Bases: Module

Graph Convolutional Network to encode embeddings with a series of GCN layers

Parameters:
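  • env_name (str) – name of the environment

  • embedding_dim (int) – dimension of the embeddings

  • num_nodes (int) – number of nodes in the graph

  • num_layers (int) – number of GCN layers

  • init_embedding (Optional[Module]) – module that computes the initial node embeddings; if None, a default initial embedding for env_name is used

  • self_loop (bool) – whether to add self-loops to the graph

  • residual (bool) – whether to use residual connections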
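The effect of self_loop and residual can be illustrated with one graph-convolution step in the standard symmetric-normalization form of Kipf and Welling, H' = D^-1/2 A D^-1/2 H W, where A includes self-loops when requested. This pure-Python sketch shows the general technique under that assumption and is not a copy of GCNEncoder; the function name and list-based matrices are hypothetical, and the residual branch requires the output width to equal the input width.

```python
import math
from typing import List

Matrix = List[List[float]]


def gcn_layer(adj: Matrix, h: Matrix, weight: Matrix,
              self_loop: bool = False, residual: bool = True) -> Matrix:
    """One GCN step: D^-1/2 A D^-1/2 H W, plus H if residual is True."""
    n = len(adj)
    a = [row[:] for row in adj]
    if self_loop:
        for i in range(n):
            a[i][i] = 1.0  # set the diagonal to 1: every node also sees itself
    deg = [sum(row) for row in a]
    inv_sqrt = [1.0 / math.sqrt(d) if d > 0 else 0.0 for d in deg]
    # Symmetrically normalized aggregation of neighbour features
    agg = [[sum(inv_sqrt[i] * a[i][j] * inv_sqrt[j] * h[j][k] for j in range(n))
            for k in range(len(h[0]))] for i in range(n)]
    # Linear transform: agg (n x d_in) times weight (d_in x d_out)
    out = [[sum(agg[i][k] * weight[k][m] for k in range(len(weight)))
            for m in range(len(weight[0]))] for i in range(n)]
    if residual:
        # Residual connection: needs d_out == d_in
        out = [[o + x for o, x in zip(orow, hrow)] for orow, hrow in zip(out, h)]
    return out
```

Stacking num_layers such steps, each with its own weight, gives the encoder's overall shape.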


forward(td, mask=None)[source]

Forward pass of the encoder. Transform the input TensorDict into a latent representation.

Parameters:
  • td (TensorDict) – Input TensorDict containing the environment state

  • mask (Optional[Tensor]) – Mask to apply to the attention

Returns:

  • h – latent representation of the input

  • init_h – initial embedding of the input

Message Passing Encoder

class rl4co.models.nn.graph.mpnn.MessagePassingEncoder(env_name, embedding_dim, num_nodes, num_layers, init_embedding=None, aggregation='add', self_loop=False, residual=True)[source]

Bases: Module

Note

  • Only fully connected graphs are supported for now.
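Because only fully connected graphs are supported, the edge list is fully determined by the number of nodes. A minimal sketch, assuming PyTorch Geometric's edge_index convention (the first row holds source nodes, the second row target nodes); the helper name is hypothetical.

```python
from typing import List, Tuple


def fully_connected_edge_index(num_nodes: int,
                               self_loop: bool = False) -> Tuple[List[int], List[int]]:
    """Return (sources, targets) covering every ordered pair of nodes."""
    src, dst = [], []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j or self_loop:  # skip i -> i unless self-loops are requested
                src.append(i)
                dst.append(j)
    return src, dst
```

A graph with N nodes therefore has N * (N - 1) edges, or N * N with self-loops, which is why memory grows quadratically with graph size.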

edge_update(nf, ef, edge_index)[source]

forward(td, mask=None)[source]

Return type:

Tuple[Tensor, Tensor]

message(edge_features)[source]

update(aggr_msg, x)[source]

class rl4co.models.nn.graph.mpnn.MessagePassingLayer(node_indim, node_outdim, edge_indim, edge_outdim, aggregation='add', residual=False, **mlp_params)[source]

Bases: MessagePassing


edge_update(nf, ef, edge_index)[source]

Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

forward(node_feature, edge_feature, edge_index, mask=None)[source]

Runs the forward pass of the module.

message(edge_features)[source]

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

update(aggr_msg, x)[source]

Updates node embeddings in analogy to \(\gamma_{\mathbf{\Theta}}\) for each node \(i \in \mathcal{V}\). Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().
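The message, aggregation, and update hooks compose into one propagation step. This pure-Python sketch uses scalar features and the 'add' aggregation that MessagePassingLayer defaults to; the function name and its default message (pass the edge feature through) and update (add the aggregate to the node feature) callables are hypothetical stand-ins for whatever learned functions the layer actually uses.

```python
from typing import Callable, List, Tuple


def propagate(
    x: List[float],
    edge_attr: List[float],
    edge_index: Tuple[List[int], List[int]],
    message: Callable[[float], float] = lambda e: e,
    update: Callable[[float, float], float] = lambda aggr, xi: xi + aggr,
) -> List[float]:
    _, dst = edge_index  # messages flow source -> target, so aggregate at dst
    # 1. message: one message per edge, computed from that edge's feature
    msgs = [message(e) for e in edge_attr]
    # 2. aggregate ('add'): sum all messages arriving at each target node
    aggr = [0.0] * len(x)
    for m, i in zip(msgs, dst):
        aggr[i] += m
    # 3. update: combine each node's aggregate with its own feature
    return [update(a, xi) for a, xi in zip(aggr, x)]
```

Swapping the sum in step 2 for a mean or max gives the other common aggregation choices.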

rl4co.models.nn.attention

class rl4co.models.nn.attention.LogitAttention(embed_dim, num_heads, tanh_clipping=10.0, mask_inner=True, mask_logits=True, normalize=True, softmax_temp=1.0, linear_bias=False, force_flash_attn=False)[source]

Bases: Module

Calculate logits given query, key, value, and logit key. If we use FlashAttention, inner computations are automatically moved to fp16. Note: with FlashAttention, masking is not supported.

Perform the following:
  1. Apply cross attention to get the heads

  2. Project heads to get glimpse

  3. Compute attention score between glimpse and logit key

  4. Normalize and mask
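Step 4 can be sketched in plain Python as tanh clipping, masking, and a temperature-scaled log-softmax, in the style of Kool et al. (2019). This is an illustration, not LogitAttention's code: the function name is hypothetical, and the sketch's convention that mask[i] is True for feasible nodes may differ from the mask convention the class expects.

```python
import math
from typing import List


def clip_mask_normalize(scores: List[float], mask: List[bool],
                        tanh_clipping: float = 10.0,
                        softmax_temp: float = 1.0) -> List[float]:
    """Return log-probabilities; mask[i] is True for feasible nodes (at least one)."""
    # Tanh clipping bounds every logit to (-C, C)
    logits = [tanh_clipping * math.tanh(s) if tanh_clipping > 0 else s for s in scores]
    # Infeasible nodes get -inf so they receive zero probability
    logits = [v if ok else -math.inf for v, ok in zip(logits, mask)]
    # Temperature-scaled log-softmax, subtracting the max for numerical stability
    scaled = [v / softmax_temp for v in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]
```

Clipping keeps the policy from becoming overconfident early in training, while the temperature flattens (above 1) or sharpens (below 1) the resulting distribution.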

Parameters:
  • embed_dim (int) – total dimension of the model

  • num_heads (int) – number of heads

  • tanh_clipping (float) – tanh clipping value

  • mask_inner (bool) – whether to mask inner attention

  • mask_logits (bool) – whether to mask logits

  • normalize (bool) – whether to normalize logits

  • softmax_temp (float) – softmax temperature

  • linear_bias (bool) – whether to use a bias term in the linear projection layers

  • force_flash_attn (bool) – whether to force FlashAttention. If True, computations are automatically cast to fp16


flash_attn_wrapper(func, *args, **kwargs)

Wrapper for flash attention to automatically cast to fp16 if needed

forward(query, key, value, logit_key, mask, softmax_temp=None)[source]


class rl4co.models.nn.attention.MultiHeadAttention(embed_dim, num_heads, bias=True, attention_dropout=0.0, causal=False, device=None, dtype=None, force_flash_attn=False)[source]

Bases: Module

PyTorch native implementation of Flash Multi-Head Attention with automatic mixed precision support. Uses PyTorch’s native scaled_dot_product_attention implementation, available from PyTorch 2.0.

Note

If scaled_dot_product_attention is not available, a custom implementation of scaled_dot_product_attention without FlashAttention is used. If you want to use FlashAttention, you may have a look at the MHA module under rl4co.models.nn.flash_attention.MHA.

Parameters:
  • embed_dim (int) – total dimension of the model

  • num_heads (int) – number of heads

  • bias (bool) – whether to use bias

  • attention_dropout (float) – dropout rate for attention weights

  • causal (bool) – whether to apply causal mask to attention scores

  • device – torch device

  • dtype – torch dtype

  • force_flash_attn (bool) – whether to force flash attention. If True, then we automatically cast to fp16


flash_attn_wrapper(func, *args, **kwargs)

Wrapper for flash attention to automatically cast to fp16 if needed

forward(x, key_padding_mask=None)[source]

Parameters:
  • x – (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim

  • key_padding_mask – bool tensor of shape (batch, seqlen)
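The relation hidden_dim = num_heads * head_dim means each token's hidden vector is cut into num_heads contiguous chunks (in PyTorch this is typically a view to (batch, seqlen, num_heads, head_dim)). A list-based sketch of that split and its inverse for a single token; the helper names are hypothetical.

```python
from typing import List


def split_heads(x: List[float], num_heads: int) -> List[List[float]]:
    """Split one hidden vector into num_heads chunks of size head_dim."""
    if len(x) % num_heads != 0:
        raise ValueError("hidden_dim must be divisible by num_heads")
    head_dim = len(x) // num_heads
    return [x[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]


def merge_heads(heads: List[List[float]]) -> List[float]:
    """Concatenate per-head outputs back into one hidden vector."""
    return [v for head in heads for v in head]
```

Each head then attends independently over its own head_dim slice before the outputs are merged and projected.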

rl4co.models.nn.attention.flash_attn_wrapper(self, func, *args, **kwargs)[source]

Wrapper for flash attention to automatically cast to fp16 if needed

rl4co.models.nn.flash_attention

class rl4co.models.nn.flash_attention.CrossAttention(causal=False, softmax_scale=None, attention_dropout=0.0, neg_inf=-10000.0)[source]

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale – the temperature to use for the softmax attention (default: 1/sqrt(d_keys), where d_keys is computed at runtime)

  • attention_dropout – the dropout rate to apply to the attention (default: 0.0)


forward(q, kv, causal=None, key_padding_mask=None)[source]

Implements the multihead softmax attention.

Parameters:
  • q – the tensor containing the query, of shape (B, Sq, H, D)

  • kv – the tensor containing the key and value, of shape (B, Sk, 2, H, D)

  • causal – if passed, will override self.causal

  • key_padding_mask – boolean mask to apply to the attention weights, of shape (B, Sk). True means to keep, False means to mask out.
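The arguments above can be illustrated with a single-head, single-batch-element sketch: keys and values arrive packed along the "2" axis of kv, the default scale is 1/sqrt(d), and key_padding_mask uses True = keep. The function name is hypothetical, and where the class signature shows a finite neg_inf=-10000.0 fill value for masked scores, this sketch uses -inf.

```python
import math
from typing import List, Optional


def cross_attention(
    q: List[List[float]],
    kv: List[List[List[float]]],
    key_padding_mask: Optional[List[bool]] = None,
    softmax_scale: Optional[float] = None,
) -> List[List[float]]:
    """One batch element, one head.

    q:   Sq query vectors of size D
    kv:  Sk entries, each [key_vector, value_vector] (the packed "2" axis)
    key_padding_mask: Sk booleans, True = keep, False = mask out
    """
    d = len(q[0])
    # Default scale 1/sqrt(d_keys), with d_keys read from the inputs at runtime
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)
    keys = [entry[0] for entry in kv]
    values = [entry[1] for entry in kv]
    keep = key_padding_mask if key_padding_mask is not None else [True] * len(kv)
    out = []
    for qi in q:
        scores = [
            scale * sum(a * b for a, b in zip(qi, k)) if ok else -math.inf
            for k, ok in zip(keys, keep)
        ]
        m = max(scores)  # assumes at least one key is kept
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w / z * v[c] for w, v in zip(weights, values))
                    for c in range(len(values[0]))])
    return out
```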

class rl4co.models.nn.flash_attention.FlashCrossAttention(causal=False, softmax_scale=None, attention_dropout=0.0, triton=False)[source]

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale – the temperature to use for the softmax attention (default: 1/sqrt(d_keys), where d_keys is computed at runtime)

  • attention_dropout – the dropout rate to apply to the attention (default: 0.0)


forward(q, kv, causal=None, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None)[source]

Implements the multihead softmax attention.

Parameters:
  • q – the tensor containing the query, of shape (B, Sq, H, D)

  • kv – the tensor containing the key and value, of shape (B, Sk, 2, H, D)

  • causal – if passed, will override self.causal

  • cu_seqlens – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q.

  • max_seqlen – int. Maximum sequence length in the batch of q.

  • cu_seqlens_k – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv.

  • max_seqlen_k – int. Maximum sequence length in the batch of k and v.
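The varlen arguments cu_seqlens and max_seqlen can be derived from a boolean padding mask. A minimal sketch, assuming True marks real tokens as in key_padding_mask; in practice cu_seqlens would be a torch.int32 tensor of shape (batch_size + 1,), and the helper name here is hypothetical.

```python
from itertools import accumulate
from typing import List, Tuple


def varlen_metadata(key_padding_mask: List[List[bool]]) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen); True in the mask marks a real token."""
    seqlens = [sum(row) for row in key_padding_mask]  # valid tokens per sequence
    cu_seqlens = [0] + list(accumulate(seqlens))      # length batch_size + 1
    return cu_seqlens, max(seqlens)
```

Sequence b then occupies rows cu_seqlens[b] up to (but excluding) cu_seqlens[b + 1] of the packed tensor.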

class rl4co.models.nn.flash_attention.FlashSelfAttention(causal=False, softmax_scale=None, attention_dropout=0.0, triton=False)[source]

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale – the temperature to use for the softmax attention (default: 1/sqrt(d_keys), where d_keys is computed at runtime)

  • attention_dropout – the dropout rate to apply to the attention (default: 0.0)


forward(qkv, causal=None, cu_seqlens=None, max_seqlen=None)[source]

Implements the multihead softmax attention.

Parameters:
  • qkv – if cu_seqlens is None and max_seqlen is None, qkv has shape (B, S, 3, H, D). If cu_seqlens is not None and max_seqlen is not None, qkv has shape (total, 3, H, D), where total is the sum of the sequence lengths in the batch.

  • causal – if passed, will override self.causal

  • cu_seqlens – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv.

  • max_seqlen – int. Maximum sequence length in the batch.

Returns:

out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, else (B, S, H, D).
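The packed (total, ...) layout concatenates the valid tokens of every sequence, and cu_seqlens marks where each sequence starts and ends. A list-based sketch of packing a padded batch and restoring it, assuming right-padded sequences; the helper names are hypothetical.

```python
from typing import List, TypeVar

T = TypeVar("T")


def unpad(batch: List[List[T]], mask: List[List[bool]]) -> List[T]:
    """(B, S, ...) padded batch -> (total, ...) packed list of valid tokens."""
    return [tok for seq, row in zip(batch, mask) for tok, keep in zip(seq, row) if keep]


def pad(packed: List[T], cu_seqlens: List[int], max_seqlen: int, fill: T) -> List[List[T]]:
    """Inverse of unpad for right-padded sequences: slice with cu_seqlens, then pad."""
    out = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        seq = packed[start:end]  # tokens of one sequence
        out.append(seq + [fill] * (max_seqlen - len(seq)))
    return out
```

Working on the packed layout avoids spending compute on padding tokens, which is the motivation for the varlen code path.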

class rl4co.models.nn.flash_attention.LinearResidual(in_features, out_features, bias=True, device=None, dtype=None)[source]

Bases: Linear

Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
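Returning the input alongside the output lets the caller apply the residual connection itself. A list-based sketch of that contract, with weight laid out as (out_features, in_features) like nn.Linear; the function name is hypothetical.

```python
from typing import List, Tuple


def linear_residual(x: List[float], weight: List[List[float]],
                    bias: List[float]) -> Tuple[List[float], List[float]]:
    """Compute y = W x + b and return it together with the unchanged input x."""
    y = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]
    return y, x
```

A post-norm block would then compute norm(y + x) from the returned pair.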


forward(input)[source]

Return type:

Tensor

class rl4co.models.nn.flash_attention.MHA(embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0, rotary_emb_scale_base=0, fused_bias_fc=False, use_flash_attn=False, force_dtype_flash_attn=True, return_residual=False, device=None, dtype=None)[source]

Bases: Module

Multi-head self-attention and cross-attention

return_residual: whether to return the input x along with the output. This is for performance reasons: for post-norm architectures, returning the input allows us to fuse the backward of nn.Linear with the residual connection.

flash_attn_wrapper(func, *args, **kwargs)

Wrapper for flash attention to automatically cast to fp16 if needed

forward(x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, mixer_subset=None, inference_params=None, **kwargs)[source]
Parameters:
  • x – (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total is the sum of the sequence lengths in the batch.

  • x_kv – (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.

  • cu_seqlens – (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into x. Only applicable when using FlashAttention.

  • max_seqlen – int. Maximum sequence length in the batch.

  • key_padding_mask – boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention.

  • mixer_subset – for cross-attention only. If not None, will take a subset of x before applying the query projection. Useful for e.g., ViT where we only care about the CLS token in the last layer.

  • inference_params – for generation. Adapted from Megatron-LM (and Apex): https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470

class rl4co.models.nn.flash_attention.SelfAttention(causal=False, softmax_scale=None, attention_dropout=0.0, _inf=-10000.0)[source]

Bases: Module

Implement the scaled dot product attention with softmax.

Parameters:
  • softmax_scale – the temperature to use for the softmax attention (default: 1/sqrt(d_keys), where d_keys is computed at runtime)

  • attention_dropout – the dropout rate to apply to the attention (default: 0.0)


forward(qkv, causal=None, key_padding_mask=None)[source]

Implements the multihead softmax attention.

Parameters:
  • qkv – the tensor containing the query, key, and value, of shape (B, S, 3, H, D)

  • causal – if passed, will override self.causal

  • key_padding_mask – boolean mask to apply to the attention weights, of shape (B, S). True means to keep, False means to mask out.

rl4co.models.nn.flash_attention.flash_attn_wrapper(self, func, *args, **kwargs)[source]

Wrapper for flash attention to automatically cast to fp16 if needed

rl4co.models.nn.mlp

class rl4co.models.nn.mlp.MLP(input_dim, output_dim, num_neurons=[64, 32], hidden_act='ReLU', out_act='Identity', input_norm='None', output_norm='None')[source]

Bases: Module
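The MLP signature suggests a stack of linear layers whose widths run from input_dim through each entry of num_neurons to output_dim, with hidden_act between layers and out_act at the end. Assuming that layout, the layer shapes can be listed as follows; the helper is hypothetical and uses a tuple default to avoid a mutable default argument.

```python
from typing import List, Sequence, Tuple


def mlp_layer_dims(input_dim: int, output_dim: int,
                   num_neurons: Sequence[int] = (64, 32)) -> List[Tuple[int, int]]:
    """(in_features, out_features) of each linear layer, in order."""
    dims = [input_dim, *num_neurons, output_dim]
    return list(zip(dims[:-1], dims[1:]))
```

With the default num_neurons, an MLP from 2 inputs to 1 output would under this assumption use three linear layers: 2 to 64, 64 to 32, and 32 to 1.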


forward(xs)[source]


rl4co.models.nn.ops

class rl4co.models.nn.ops.Normalization(embed_dim, normalization='batch')[source]

Bases: Module
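The normalization choices used throughout this page (batch, layer, none) differ in which axis the statistics are taken over for a [batch, nodes, embed_dim] input. A pure-Python sketch, assuming 'batch' pools statistics per feature over all batch elements and nodes (as in Kool et al.'s encoder) and omitting learnable scale/shift and running statistics; the function name is hypothetical.

```python
import math
from typing import List

Tensor3 = List[List[List[float]]]  # [batch, nodes, embed_dim]


def _standardize(values: List[float], eps: float) -> List[float]:
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def normalize(x: Tensor3, normalization: str = "batch", eps: float = 1e-5) -> Tensor3:
    if normalization == "layer":
        # Statistics over embed_dim, separately for every (batch, node) vector
        return [[_standardize(v, eps) for v in nodes] for nodes in x]
    if normalization == "batch":
        # Statistics per feature, pooled over all batch elements and nodes
        flat = [v for nodes in x for v in nodes]  # [batch * nodes, embed_dim]
        cols = [_standardize([row[d] for row in flat], eps) for d in range(len(flat[0]))]
        flat_norm = [[cols[d][r] for d in range(len(cols))] for r in range(len(flat))]
        n = len(x[0])
        return [flat_norm[b * n:(b + 1) * n] for b in range(len(x))]
    return x  # no normalization: identity
```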


forward(x)[source]


init_parameters()[source]

class rl4co.models.nn.ops.SkipConnection(module)[source]

Bases: Module
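Assuming SkipConnection wraps a module so that its output is x + module(x) (the residual form used in Kool et al.'s encoder), the behavior can be sketched with a plain function wrapper; the name is hypothetical.

```python
from typing import Callable, List

Fn = Callable[[List[float]], List[float]]


def skip_connection(module: Fn) -> Fn:
    """Wrap module so that the wrapper computes x + module(x)."""
    def forward(x: List[float]) -> List[float]:
        return [a + b for a, b in zip(x, module(x))]
    return forward
```

The wrapped module must preserve the input dimension for the element-wise sum to be defined.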


forward(x)[source]


rl4co.models.nn.utils

class rl4co.models.nn.utils.RandomPolicy(env_name=None)[source]

Bases: Module

Random policy that randomly selects actions from the action space. This policy can be useful to check the sanity of the environment during development.

The signature of forward matches that of the AutoregressivePolicy class.


forward(td, env=None, max_steps=None)[source]


rl4co.models.nn.utils.decode_probs(probs, mask, decode_type='sampling')[source]

Decode probabilities into actions, respecting the action mask
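The two decoding modes can be sketched for a single instance: greedy takes the most probable feasible action, sampling draws among feasible actions in proportion to their probabilities. This is an illustration, not rl4co's code; the function name is hypothetical, and the convention that mask[i] is True for feasible actions is an assumption of the sketch.

```python
import random
from typing import List, Optional


def decode_probs_sketch(probs: List[float], mask: List[bool],
                        decode_type: str = "sampling",
                        rng: Optional[random.Random] = None) -> int:
    """Select one action index; mask[i] is True for feasible actions."""
    feasible = [i for i, ok in enumerate(mask) if ok]
    if decode_type == "greedy":
        # Highest-probability feasible action
        return max(feasible, key=lambda i: probs[i])
    if decode_type == "sampling":
        # Sample among feasible actions, proportionally to their probabilities
        rng = rng or random.Random()
        return rng.choices(feasible, weights=[probs[i] for i in feasible], k=1)[0]
    raise ValueError(f"unknown decode_type: {decode_type!r}")
```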

rl4co.models.nn.utils.get_log_likelihood(log_p, actions, mask, return_sum=True)[source]

Get log likelihood of selected actions
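For one instance, the log-likelihood of a decoded solution is the sum of the log-probabilities of the actions taken at each step. A minimal sketch, assuming log_p has shape [steps, num_actions] and that steps with mask False (for example, padding after an instance is done) contribute zero; the function name is hypothetical.

```python
from typing import List, Optional, Union


def get_log_likelihood_sketch(
    log_p: List[List[float]],
    actions: List[int],
    mask: Optional[List[bool]] = None,
    return_sum: bool = True,
) -> Union[float, List[float]]:
    """log_p: [steps, num_actions] log-probabilities; actions: [steps] chosen indices."""
    # Pick the log-probability of the action actually taken at each step
    ll = [row[a] for row, a in zip(log_p, actions)]
    if mask is not None:
        # Steps with mask False contribute 0
        ll = [v if keep else 0.0 for v, keep in zip(ll, mask)]
    return sum(ll) if return_sum else ll
```

Summing is what a REINFORCE-style loss needs, since the probability of a whole solution is the product of its per-step probabilities.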

rl4co.models.nn.utils.random_policy(td)[source]

Helper function to select a random action from available actions
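Assuming "random" means a uniform choice among the actions the mask marks as available, the selection can be sketched as follows; the function name and the True = available convention are assumptions of the sketch.

```python
import random
from typing import List, Optional


def random_action(action_mask: List[bool], rng: Optional[random.Random] = None) -> int:
    """Pick one available action (mask entry True) uniformly at random."""
    available = [i for i, ok in enumerate(action_mask) if ok]
    if not available:
        raise ValueError("no available action")
    rng = rng or random.Random()
    return rng.choice(available)
```

Because it never picks a masked action, a rollout with this policy that still ends in an invalid state points to a bug in the environment rather than the model.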

rl4co.models.nn.utils.rollout(env, td, policy, max_steps=None)[source]

Helper function to roll out a policy. Currently, TorchRL's env.rollout() does not allow stepping over environments that are already done. We need this for batched environments whose instances complete at different steps.
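The idea can be sketched as a loop that keeps stepping the whole batch until every instance reports done, with max_steps as an optional safety cap. ToyBatchEnv and rollout_sketch are hypothetical: in this toy, instances that are already done simply ignore their actions, which is one way an environment can tolerate being stepped after it has finished.

```python
from typing import Callable, List, Optional


class ToyBatchEnv:
    """Hypothetical batched env: instance b is done after lengths[b] steps."""

    def __init__(self, lengths: List[int]):
        self.lengths = lengths
        self.t = [0] * len(lengths)

    def done(self) -> List[bool]:
        return [t >= n for t, n in zip(self.t, self.lengths)]

    def step(self, actions: List[int]) -> None:
        # The toy ignores action values; finished instances do not advance
        self.t = [t + (0 if d else 1) for t, d in zip(self.t, self.done())]


def rollout_sketch(env: ToyBatchEnv, policy: Callable[[List[bool]], List[int]],
                   max_steps: Optional[int] = None) -> List[List[int]]:
    """Step the whole batch until all instances are done (or max_steps is hit)."""
    actions_per_step = []
    steps = 0
    while not all(env.done()):
        if max_steps is not None and steps >= max_steps:
            break
        actions = policy(env.done())
        env.step(actions)
        actions_per_step.append(actions)
        steps += 1
    return actions_per_step
```

The number of steps equals the length of the longest instance in the batch, so shorter instances are stepped a few extra times while already done.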