Shortcuts

Source code for rl4co.models.zoo.matnet.decoder

from dataclasses import dataclass
from typing import Optional, Tuple, Union

from tensordict import TensorDict
from torch import Tensor

from rl4co.models.zoo.common.autoregressive.decoder import AutoregressiveDecoder


@dataclass
class PrecomputedCache:
    """Bundle of tensors computed once by a decoder's ``_precompute_cache``.

    Attributes:
        node_embeddings: Embeddings stored for later lookup. ``MatNetDecoder``
            stores the row embeddings here.
        graph_context: Projected context tensor, or the scalar ``0`` when the
            graph context is disabled.
        glimpse_key: Key tensor; in ``MatNetDecoder`` the first chunk of the
            projected column embeddings.
        glimpse_val: Value tensor; in ``MatNetDecoder`` the second chunk of
            the projected column embeddings.
        logit_key: Key tensor; in ``MatNetDecoder`` the third chunk of the
            projected column embeddings.
    """

    node_embeddings: Tensor
    graph_context: Union[Tensor, float]
    glimpse_key: Tensor
    glimpse_val: Tensor
    logit_key: Tensor
class MatNetDecoder(AutoregressiveDecoder):
    """Autoregressive decoder whose cache is built from a pair of embeddings.

    Only ``_precompute_cache`` is overridden here; all other behaviour comes
    from ``AutoregressiveDecoder``.
    """

    def _precompute_cache(
        self,
        embeddings: Tuple[Tensor, Tensor],
        td: Optional[TensorDict] = None,
    ) -> PrecomputedCache:
        """Build the cache from column and row embeddings.

        Args:
            embeddings: Pair ``(col_emb, row_emb)``. The column embeddings
                feed the key/value/logit projections and the graph context;
                the row embeddings are stored as ``node_embeddings``.
            td: Not read by this method.
                NOTE(review): presumably kept for signature compatibility
                with the parent class; confirm.

        Returns:
            A ``PrecomputedCache`` holding the row embeddings, the graph
            context (``0`` when ``self.use_graph_context`` is falsy), and the
            three chunks of the projected column embeddings.
        """
        col_emb, row_emb = embeddings

        # One projection of the column embeddings yields all three attention
        # tensors; split it into three chunks along the last dimension.
        projected = self.project_node_embeddings(col_emb)
        glimpse_key_fixed, glimpse_val_fixed, logit_key = projected.chunk(3, dim=-1)

        # Optionally disable the graph context from the initial embedding as
        # done in POMO.
        if self.use_graph_context:
            # Mean over dim 1; presumably the node axis (TODO confirm).
            graph_context = self.project_fixed_context(col_emb.mean(1))
        else:
            graph_context = 0

        # Organize in a dataclass for easy access.
        return PrecomputedCache(
            node_embeddings=row_emb,
            graph_context=graph_context,
            glimpse_key=glimpse_key_fixed,
            glimpse_val=glimpse_val_fixed,
            logit_key=logit_key,
        )

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.