Convert TF batch normalization to PyTorch

The snippet below builds a Conv2d + BatchNorm2d block in PyTorch and copies per-channel gamma, beta, moving mean, and moving variance values (random placeholders here) into the BatchNorm2d parameters and buffers, mirroring a TF conv + batch_norm pair.
import numpy as np
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.module_list = nn.ModuleList()
        module = nn.Sequential()

        # Conv without bias: the shift is handled by the BatchNorm that follows,
        # matching the usual TF conv + batch_norm pattern.
        conv = nn.Conv2d(3, 32, 3, 1, 1, bias=False)
        module.add_module('conv_0', conv)

        bn = nn.BatchNorm2d(32)
        module.add_module('batch_norm_0', bn)

        # TF gamma -> PyTorch weight (per-channel scale).
        gamma = np.random.rand(32)
        bn.weight.data.copy_(torch.from_numpy(gamma))

        # TF beta -> PyTorch bias (per-channel shift).
        beta = np.random.rand(32)
        bn.bias.data.copy_(torch.from_numpy(beta))

        # TF moving_mean -> running_mean buffer.
        mean = np.random.rand(32)
        bn.running_mean.copy_(torch.from_numpy(mean))

        # TF moving_variance -> running_var buffer.
        var = np.random.rand(32)
        bn.running_var.copy_(torch.from_numpy(var))

        self.module_list.append(module)

    def forward(self, input):
        conv = self.module_list[0][0](input)
        bn = self.module_list[0][1](conv)
        return conv, bn


if __name__ == '__main__':
    x = torch.from_numpy(np.random.rand(1, 3, 64, 64)).float()
    model = Model()
    model.eval()  # use the copied running statistics instead of batch statistics
    with torch.no_grad():
        conv_out, bn_out = model(x)
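In the code above the statistics are random placeholders. A minimal sketch of copying the real values out of a tf.keras BatchNormalization layer instead might look like the following; the helper name load_tf_bn and the argument tf_bn are illustrative assumptions, not part of the original snippet, and it assumes the TF layer was built with scale=True and center=True so get_weights() returns [gamma, beta, moving_mean, moving_variance].

import numpy as np
import torch
import torch.nn as nn


def load_tf_bn(pt_bn: nn.BatchNorm2d, tf_bn) -> None:
    # Hypothetical helper: copy a tf.keras BatchNormalization layer's weights
    # into an nn.BatchNorm2d with the same number of channels.
    gamma, beta, moving_mean, moving_var = tf_bn.get_weights()
    with torch.no_grad():
        pt_bn.weight.copy_(torch.from_numpy(np.asarray(gamma)))
        pt_bn.bias.copy_(torch.from_numpy(np.asarray(beta)))
        pt_bn.running_mean.copy_(torch.from_numpy(np.asarray(moving_mean)))
        pt_bn.running_var.copy_(torch.from_numpy(np.asarray(moving_var)))
    # TF defaults to epsilon=1e-3 while PyTorch uses eps=1e-5; keep them in sync.
    pt_bn.eps = tf_bn.epsilon

After copying, calling model.eval() makes the PyTorch layer normalize with the copied running statistics, so its output should match the TF layer's inference output up to floating-point tolerance.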