Source code for rl4co.models.nn.ops
import math
import torch.nn as nn
class SkipConnection(nn.Module):
    """Residual wrapper: computes ``x + module(x)``.

    Args:
        module: Sub-module applied to the input; its output must be
            addable to (or broadcastable with) the input tensor.
    """

    def __init__(self, module):
        super().__init__()
        # Stored as ``self.module`` so state_dict keys stay "module.*".
        self.module = module

    def forward(self, x):
        """Return the input plus the wrapped module's output on that input."""
        residual = self.module(x)
        return x + residual
class Normalization(nn.Module):
    """Normalize the last (feature) dimension of an embedding tensor.

    Args:
        embed_dim: Size of the feature dimension being normalized.
        normalization: ``"batch"`` selects :class:`nn.BatchNorm1d`,
            ``"instance"`` selects :class:`nn.InstanceNorm1d` (both with
            ``affine=True``), and ``None`` disables normalization, so
            ``forward`` returns its input unchanged.

    Raises:
        ValueError: If ``normalization`` is not one of the values above.
    """

    def __init__(self, embed_dim, normalization="batch"):
        super().__init__()
        normalizer_class = {"batch": nn.BatchNorm1d, "instance": nn.InstanceNorm1d}.get(
            normalization, None
        )
        if normalizer_class is not None:
            self.normalizer = normalizer_class(embed_dim, affine=True)
        elif normalization is None:
            # No normalization: forward() acts as the identity.
            self.normalizer = None
        else:
            # Previously this crashed with "'NoneType' object is not callable".
            raise ValueError(
                f"Unknown normalization {normalization!r}; "
                "expected 'batch', 'instance' or None"
            )

    def init_parameters(self):
        """Re-initialize every parameter uniformly in ``[-s, s)``.

        Here ``s = 1 / sqrt(d)``, where ``d`` is the size of the parameter's
        last dimension. When normalization is disabled there are no
        parameters, so this is a no-op.
        """
        for param in self.parameters():
            stdv = 1.0 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Normalize ``x`` over its last (feature) dimension.

        Args:
            x: Tensor whose last dimension has size ``embed_dim``. With
                ``"instance"`` normalization it must be 3-D; presumably
                ``(batch, num_items, embed_dim)``, which should be confirmed
                against callers.

        Returns:
            A tensor with the same shape as ``x`` (``x`` itself when
            normalization is disabled).

        Raises:
            TypeError: If ``self.normalizer`` was replaced by an unsupported
                module.
        """
        if isinstance(self.normalizer, nn.BatchNorm1d):
            # Flatten all leading dims so BatchNorm1d sees (N, C); each feature
            # is normalized over every leading position. Use reshape rather
            # than view so non-contiguous inputs (e.g. transposed) also work.
            return self.normalizer(x.reshape(-1, x.size(-1))).view(x.shape)
        if isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move features to dim 1 and back.
            return self.normalizer(x.permute(0, 2, 1)).permute(0, 2, 1)
        if self.normalizer is None:
            return x
        # Explicit raise instead of assert, so the check survives `python -O`.
        raise TypeError("Unknown normalizer type")