Shortcuts

Source code for rl4co.models.zoo.ham.model

from typing import Union

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl import REINFORCE
from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline
from rl4co.models.zoo.ham.policy import HeterogeneousAttentionModelPolicy


[docs]class HeterogeneousAttentionModel(REINFORCE): """Heterogenous Attention Model for solving the Pickup and Delivery Problem based on REINFORCE: https://arxiv.org/abs/2110.02634. Args: env: Environment to use for the algorithm policy: Policy to use for the algorithm baseline: REINFORCE baseline. Defaults to rollout (1 epoch of exponential, then greedy rollout baseline) policy_kwargs: Keyword arguments for policy baseline_kwargs: Keyword arguments for baseline **kwargs: Keyword arguments passed to the superclass """ def __init__( self, env: RL4COEnvBase, policy: HeterogeneousAttentionModelPolicy = None, baseline: Union[REINFORCEBaseline, str] = "rollout", policy_kwargs={}, baseline_kwargs={}, **kwargs, ): assert ( env.name == "pdp" ), "HeterogeneousAttentionModel only works for PDP (Pickup and Delivery Problem)" if policy is None: policy = HeterogeneousAttentionModelPolicy(env.name, **policy_kwargs) super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.