# BatchNorm1d in PyTorch
class network(nn.Module):
    """Two-layer MLP classifier with batch normalization on the hidden layer.

    Architecture:
        Linear(in_features -> hidden_features) -> BatchNorm1d -> ReLU
        -> Linear(hidden_features -> out_features) -> softmax over dim 1.

    Args:
        in_features: Number of features per input sample (default 40).
        hidden_features: Width of the hidden layer (default 320).
        out_features: Number of output classes (default 2).
    """

    def __init__(self, in_features=40, hidden_features=320, out_features=2):
        super().__init__()
        self.linear1 = nn.Linear(
            in_features=in_features, out_features=hidden_features
        )
        # num_features must match the hidden width: BatchNorm1d normalizes
        # each of the hidden_features channels of the (batch, hidden) tensor.
        self.bn1 = nn.BatchNorm1d(num_features=hidden_features)
        self.linear2 = nn.Linear(
            in_features=hidden_features, out_features=out_features
        )

    def forward(self, input):
        """Return per-class probabilities for a batch of samples.

        Args:
            input: 2D tensor of shape (batch, in_features). A 1D tensor is
                not accepted: BatchNorm1d raises on 1D input, and the softmax
                over dim=1 needs a second axis. In training mode the batch
                must also hold more than one sample, otherwise BatchNorm1d
                cannot compute per-channel statistics and raises.

        Returns:
            Tensor of shape (batch, out_features) whose rows sum to 1.
        """
        y = F.relu(self.bn1(self.linear1(input)))
        # NOTE(review): this returns probabilities, not logits. If the output
        # is fed to nn.CrossEntropyLoss (which applies log-softmax itself),
        # softmax ends up applied twice -- confirm against the training loss.
        y = F.softmax(self.linear2(y), dim=1)
        return y
# Build the model, then a random batch of 10 samples with 40 features each
# (BatchNorm1d needs more than one sample per batch in training mode).
model, x = network(), torch.randn((10, 40))
# The module is left in its default training mode, so this forward pass
# normalizes with batch statistics and updates bn1's running statistics.
output = model(x)