Shortcuts

Source code for rl4co.models.nn.mlp

from typing import List

import torch.nn as nn


[docs]class MLP(nn.Module): def __init__( self, input_dim: int, output_dim: int, num_neurons: List[int] = [64, 32], hidden_act: str = "ReLU", out_act: str = "Identity", input_norm: str = "None", output_norm: str = "None", ): super(MLP, self).__init__() assert input_norm in ["Batch", "Layer", "None"] assert output_norm in ["Batch", "Layer", "None"] self.input_dim = input_dim self.output_dim = output_dim self.num_neurons = num_neurons self.hidden_act = getattr(nn, hidden_act)() self.out_act = getattr(nn, out_act)() input_dims = [input_dim] + num_neurons output_dims = num_neurons + [output_dim] self.lins = nn.ModuleList() for i, (in_dim, out_dim) in enumerate(zip(input_dims, output_dims)): self.lins.append(nn.Linear(in_dim, out_dim)) self.input_norm = self._get_norm_layer(input_norm, input_dim) self.output_norm = self._get_norm_layer(output_norm, output_dim)
[docs] def forward(self, xs): xs = self.input_norm(xs) for i, lin in enumerate(self.lins[:-1]): xs = lin(xs) xs = self.hidden_act(xs) xs = self.lins[-1](xs) xs = self.out_act(xs) xs = self.output_norm(xs) return xs
@staticmethod def _get_norm_layer(norm_method, dim): if norm_method == "Batch": in_norm = nn.BatchNorm1d(dim) elif norm_method == "Layer": in_norm = nn.LayerNorm(dim) elif norm_method == "None": in_norm = nn.Identity() # kinda placeholder else: raise RuntimeError( "Not implemented normalization layer type {}".format(norm_method) ) return in_norm def _get_act(self, is_last): return self.out_act if is_last else self.hidden_act

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision f4bc96ca.

Built with Sphinx using a theme provided by Read the Docs.