Shortcuts

Source code for rl4co.models.zoo.mdam.policy

from typing import Union

from tensordict import TensorDict

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.models.zoo.mdam.decoder import Decoder
from rl4co.models.zoo.mdam.encoder import GraphAttentionEncoder
from rl4co.utils.pylogger import get_pylogger

# Module-level logger, obtained from the project's logging helper and named after this module.
log = get_pylogger(__name__)


class MDAMPolicy(AutoregressivePolicy):
    """Multi-Decoder Attention Model (MDAM) policy.

    Builds a ``GraphAttentionEncoder`` and an MDAM ``Decoder``, hands both to
    the ``AutoregressivePolicy`` base class, and then installs an
    environment-specific initial embedding created by ``env_init_embedding``.

    Args:
        env_name: Name of the environment. Forwarded to the base class, the
            decoder and ``env_init_embedding``.
        embedding_dim: Embedding size. Passed to the encoder (as
            ``embed_dim``), the decoder, the base class and the initial
            embedding.
        num_encoder_layers: Passed to the encoder as ``num_layers`` and to
            the base class.
        num_heads: Number of attention heads, passed to the encoder, the
            decoder and the base class.
        normalization: Normalization name passed to the encoder and the base
            class.
        **kwargs: Extra keyword arguments. The same mapping is forwarded to
            the encoder, the decoder and the base-class constructor, so every
            key must be accepted by all three.
    """

    def __init__(
        self,
        env_name: str,
        embedding_dim: int = 128,
        num_encoder_layers: int = 3,
        num_heads: int = 8,
        normalization: str = "batch",
        **kwargs,
    ):
        # Build the sub-modules first (encoder, then decoder -- same order as
        # they would be evaluated as call arguments) and pass them to the base.
        encoder = GraphAttentionEncoder(
            num_heads=num_heads,
            embed_dim=embedding_dim,
            num_layers=num_encoder_layers,
            normalization=normalization,
            **kwargs,
        )
        decoder = Decoder(
            env_name=env_name,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            **kwargs,
        )
        super().__init__(
            env_name=env_name,
            encoder=encoder,
            decoder=decoder,
            embedding_dim=embedding_dim,
            num_encoder_layers=num_encoder_layers,
            num_heads=num_heads,
            normalization=normalization,
            **kwargs,
        )

        # Assigned after the base constructor so that this embedding is the
        # one in effect on the instance.
        self.init_embedding = env_init_embedding(
            env_name, {"embedding_dim": embedding_dim}
        )

    def forward(
        self,
        td: TensorDict,
        env: Union[str, RL4COEnvBase, None] = None,
        phase: str = "train",
        return_actions: bool = False,
        **decoder_kwargs,
    ) -> dict:
        """Embed and encode ``td``, then run the decoder on the result.

        Args:
            td: Input tensordict; fed to the initial embedding and to the
                decoder.
            env: Environment instance, environment name, or ``None``. A
                string is resolved with ``get_env``; ``None`` falls back to
                ``self.env_name``.
            phase: Used to look up ``self.<phase>_decode_type`` when
                ``decoder_kwargs`` carries no (non-``None``) ``decode_type``.
                NOTE(review): these attributes are not set in this class --
                presumably provided by the base class; confirm.
            return_actions: If true, the decoder's actions are included in
                the output; otherwise the ``"actions"`` entry is ``None``.
            **decoder_kwargs: Extra keyword arguments forwarded to the
                decoder.

        Returns:
            A plain ``dict`` with the keys ``"reward"``, ``"log_likelihood"``,
            ``"entropy"`` and ``"actions"``.
        """
        embedding = self.init_embedding(td)
        encoded_inputs, _, attn, V, h_old = self.encoder(embedding)

        # Instantiate environment if needed
        if isinstance(env, str) or env is None:
            env_name = self.env_name if env is None else env
            log.info(f"Instantiated environment not provided; instantiating {env_name}")
            env = get_env(env_name)

        # Get decode type depending on phase
        if decoder_kwargs.get("decode_type", None) is None:
            decoder_kwargs["decode_type"] = getattr(self, f"{phase}_decode_type")

        reward, log_likelihood, kl_divergence, actions = self.decoder(
            td, encoded_inputs, env, attn, V, h_old, **decoder_kwargs
        )

        # NOTE: the decoder's third output (bound to ``kl_divergence`` here)
        # is reported under the "entropy" key; the key is kept as-is because
        # callers read it by that name.
        out = {
            "reward": reward,
            "log_likelihood": log_likelihood,
            "entropy": kl_divergence,
            "actions": actions if return_actions else None,
        }
        return out

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.