Shortcuts

Source code for rl4co.models.zoo.matnet.model

from typing import Any, Optional, Union

import torch.nn as nn

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.zoo.matnet.policy import MatNetPolicy
from rl4co.models.zoo.pomo.model import POMO


class MatNet(POMO):
    """MatNet model trained with the POMO base class.

    Thin wrapper around :class:`POMO`. When no policy is supplied, it builds
    a :class:`MatNetPolicy` for the given environment. All training options
    are then forwarded to the POMO constructor.

    Args:
        env: Environment instance; its ``name`` attribute is used to build
            the default policy.
        policy: Policy network. If ``None``, a policy is created as
            ``MatNetPolicy(env_name=env.name, **policy_params)``.
        optimizer_kwargs: Optimizer keyword arguments forwarded to POMO.
            ``None`` (default) means ``{"lr": 4e-4, "weight_decay": 1e-6}``.
        lr_scheduler: Name of the learning-rate scheduler, forwarded to POMO.
        lr_scheduler_kwargs: Scheduler keyword arguments forwarded to POMO.
            ``None`` (default) means
            ``{"milestones": [2001, 2101], "gamma": 0.1}``.
        use_dihedral_8: Forwarded to POMO. Presumably this toggles
            dihedral-8 augmentation; confirm against POMO.
        num_starts: Forwarded to POMO unchanged (default ``None``).
        train_data_size: Forwarded to POMO.
        batch_size: Forwarded to POMO.
        policy_params: Extra keyword arguments for the default
            ``MatNetPolicy``. Only used when ``policy`` is ``None``.
            ``None`` means no extra arguments.
        model_params: Extra keyword arguments forwarded verbatim to the POMO
            constructor. ``None`` means no extra arguments.
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: Optional[Union[nn.Module, MatNetPolicy]] = None,
        optimizer_kwargs: Optional[dict] = None,
        lr_scheduler: str = "MultiStepLR",
        lr_scheduler_kwargs: Optional[dict] = None,
        use_dihedral_8: bool = False,
        num_starts: Optional[int] = None,
        train_data_size: int = 10_000,
        batch_size: int = 200,
        policy_params: Optional[dict] = None,
        model_params: Optional[dict] = None,
    ):
        # Build fresh default dicts on every call. Dict literals in the
        # signature would be one object shared by all MatNet instances,
        # so an in-place change made downstream would leak into every
        # later instance.
        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 4e-4, "weight_decay": 1e-6}
        if lr_scheduler_kwargs is None:
            lr_scheduler_kwargs = {"milestones": [2001, 2101], "gamma": 0.1}
        if policy_params is None:
            policy_params = {}
        if model_params is None:
            model_params = {}

        # Only build the default policy when the caller did not supply one;
        # policy_params is ignored otherwise.
        if policy is None:
            policy = MatNetPolicy(env_name=env.name, **policy_params)

        super().__init__(
            env=env,
            policy=policy,
            optimizer_kwargs=optimizer_kwargs,
            lr_scheduler=lr_scheduler,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            use_dihedral_8=use_dihedral_8,
            num_starts=num_starts,
            train_data_size=train_data_size,
            batch_size=batch_size,
            **model_params,
        )

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.