Shortcuts

Source code for rl4co.models.nn.ops

import math

import torch.nn as nn


class SkipConnection(nn.Module):
    """Residual wrapper that returns ``x + module(x)``.

    Args:
        module: Callable (typically an ``nn.Module``) applied to the input;
            its output must be addable to ``x``.
    """

    def __init__(self, module):
        super().__init__()
        # Assigning an nn.Module to an attribute registers it as a submodule,
        # so its parameters are visible through this wrapper.
        self.module = module

    def forward(self, x):
        """Apply the wrapped module and add its output back onto ``x``."""
        residual = self.module(x)
        return x + residual
class Normalization(nn.Module):
    """Feature-wise normalization over the last dimension of the input.

    Args:
        embed_dim: Number of features, i.e. the size of the input's last
            dimension.
        normalization: ``"batch"`` selects ``nn.BatchNorm1d`` and
            ``"instance"`` selects ``nn.InstanceNorm1d`` (both with
            ``affine=True``). Any other value (e.g. ``None`` or ``"none"``)
            disables normalization, making ``forward`` the identity.
    """

    def __init__(self, embed_dim, normalization="batch"):
        super(Normalization, self).__init__()
        normalizer_class = {
            "batch": nn.BatchNorm1d,
            "instance": nn.InstanceNorm1d,
        }.get(normalization, None)
        # forward() treats a ``None`` normalizer as "no normalization", so an
        # unrecognised name builds nothing instead of calling ``None`` (which
        # used to fail with "'NoneType' object is not callable").
        if normalizer_class is None:
            self.normalizer = None
        else:
            self.normalizer = normalizer_class(embed_dim, affine=True)

    def init_parameters(self):
        """Re-initialise every parameter uniformly in ``[-1/sqrt(d), 1/sqrt(d)]``.

        ``d`` is the size of each parameter's last dimension. Does nothing
        when normalization is disabled (there are no parameters).
        """
        for param in self.parameters():
            stdv = 1.0 / math.sqrt(param.size(-1))
            # nn.init.uniform_ fills in place under no_grad, avoiding `.data`.
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, x):
        """Normalize ``x`` over its last (feature) dimension.

        Args:
            x: Tensor whose last dimension has size ``embed_dim``. For
                ``"instance"`` it is permuted with ``(0, 2, 1)``, so it must be
                3-D there — presumably ``(batch, seq, embed_dim)``; confirm
                against callers.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if isinstance(self.normalizer, nn.BatchNorm1d):
            # Flatten all leading dims so statistics are per feature across
            # every position; reshape (unlike view) also accepts
            # non-contiguous input.
            return self.normalizer(x.reshape(-1, x.size(-1))).reshape(x.shape)
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move features to dim 1 and back.
            return self.normalizer(x.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            assert self.normalizer is None, "Unknown normalizer type"
            return x

© Copyright Federico Berto, Chuanbo Hua, Junyoung Park. Revision 14d072ed.

Built with Sphinx using a theme provided by Read the Docs.