PyTorch Batch Normalization 批标准化
来源:互联网 发布:linux 创建用户脚本 编辑:程序博客网 时间:2024/05/30 05:30
简书地址
"""Compare a deep tanh MLP trained with and without batch normalization.

Fits noisy samples of y = x**2 - 5 on [-7, 10]. During training it animates
per-layer histograms of pre-activations and activations, then plots the
per-epoch test loss and both nets' final fits.
"""
import numpy as np
import torch
import torch.utils.data as Data
from torch import nn
from torch.nn import init

N_SAMPLES = 2000
BATCH_SIZE = 64
EPOCH = 12
LR = 0.03
N_HIDDEN = 8
ACTIVATION = torch.tanh  # F.tanh is deprecated in favour of torch.tanh
B_INIT = -0.2  # bias value every Linear layer is initialised to


def make_quadratic_data(n_samples):
    """Sample ``y = x**2 - 5 + noise`` on an even grid over [-7, 10].

    ``noise`` is drawn from numpy's global RNG as N(0, std=2).

    Returns:
        (x, y): float64 numpy arrays of shape (n_samples, 1).
    """
    x = np.linspace(-7, 10, n_samples)[:, np.newaxis]
    noise = np.random.normal(0, 2, x.shape)
    y = np.square(x) - 5 + noise
    return x, y


def make_loader(train_x, train_y, batch_size=BATCH_SIZE, num_workers=2):
    """Wrap the training tensors in a shuffling DataLoader."""
    # TensorDataset takes its tensors positionally; the data_tensor= /
    # target_tensor= keywords were removed in PyTorch 0.4.
    dataset = Data.TensorDataset(train_x, train_y)
    return Data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )


class Net(nn.Module):
    """Deep tanh MLP used to compare training with and without batch norm.

    N_HIDDEN fully connected layers of width 10 feed a single-output head.
    With ``batch_normalization=True`` the input and every hidden
    pre-activation pass through ``BatchNorm1d`` (momentum 0.5).
    """

    def __init__(self, batch_normalization=False):
        super().__init__()
        self.do_bn = batch_normalization
        # Plain lists give indexed access in forward(). The layers are still
        # registered as submodules through setattr (fc0, bn0, ...), which
        # keeps the original parameter / state_dict names.
        self.fcs = []
        self.bns = []
        # Always created, but only applied in forward() when do_bn is True.
        self.bn_input = nn.BatchNorm1d(1, momentum=0.5)
        for i in range(N_HIDDEN):
            fc = nn.Linear(1 if i == 0 else 10, 10)
            setattr(self, f'fc{i}', fc)
            self._set_init(fc)
            self.fcs.append(fc)
            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, f'bn{i}', bn)
                self.bns.append(bn)
        self.predict = nn.Linear(10, 1)
        self._set_init(self.predict)

    def _set_init(self, layer):
        """Draw weights from N(0, std=0.1) and set every bias to B_INIT."""
        # In-place initialisers; the un-suffixed init.normal / init.constant
        # are deprecated aliases of these.
        init.normal_(layer.weight, mean=0.0, std=0.1)
        init.constant_(layer.bias, B_INIT)

    def forward(self, x):
        """Run the MLP, keeping every layer's input and pre-activation.

        Args:
            x: network input; the first Linear takes one feature, so this is
                shaped (batch, 1) as built in ``main()``.

        Returns:
            (out, layer_input, pre_activation):
            - ``out`` is the prediction.
            - ``layer_input`` holds the N_HIDDEN + 1 tensors fed into each
              Linear (the possibly-normalised input, then each activation).
            - ``pre_activation`` holds the raw input, then each hidden
              Linear's output before BN.
        """
        pre_activation = [x]
        if self.do_bn:
            x = self.bn_input(x)
        layer_input = [x]
        for i, fc in enumerate(self.fcs):
            x = fc(x)
            pre_activation.append(x)
            if self.do_bn:
                x = self.bns[i](x)
            x = ACTIVATION(x)
            layer_input.append(x)
        out = self.predict(x)
        return out, layer_input, pre_activation


def evaluate(nets, losses, test_x, test_y, loss_func):
    """Score each net on the test set and record its loss.

    Each net is put in eval mode (BN uses its running statistics), run
    without autograd, then returned to training mode. One Python float per
    net is appended to the matching list in ``losses``.

    Returns:
        (layer_inputs, pre_acts): per-net intermediate tensors from
        ``Net.forward``, in the same order as ``nets``.
    """
    layer_inputs, pre_acts = [], []
    for net, record in zip(nets, losses):
        net.eval()
        # Replaces Variable(volatile=True), which has had no effect since
        # PyTorch 0.4; without this, evaluation would build autograd graphs.
        with torch.no_grad():
            pred, layer_input, pre_act = net(test_x)
            # .item(): indexing a 0-dim loss tensor with .data[0] raises.
            record.append(loss_func(pred, test_y).item())
        layer_inputs.append(layer_input)
        pre_acts.append(pre_act)
        net.train()
    return layer_inputs, pre_acts


def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn, axs=None):
    """Redraw the 4 x (N_HIDDEN + 1) grid of per-layer histograms.

    Rows, top to bottom:
    - pre-activations of the plain net;
    - pre-activations of the BN net;
    - layer inputs of the plain net;
    - layer inputs of the BN net.

    Column 0 is the network input; column i > 0 is hidden layer i.

    Args:
        l_in, l_in_bn: ``layer_input`` lists from the plain / BN net.
        pre_ac, pre_ac_bn: ``pre_activation`` lists from the plain / BN net.
        axs: 4 x (N_HIDDEN + 1) array of Axes. If omitted, the current
            figure's axes are used in creation (row-major) order.
    """
    # Lazy import: the model/data code above stays importable without a
    # plotting backend (e.g. in DataLoader worker processes).
    import matplotlib.pyplot as plt

    if axs is None:
        axs = np.asarray(plt.gcf().axes).reshape(4, -1)

    def to_np(t):
        return t.detach().numpy().ravel()

    columns = zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])
    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(columns):
        for a in (ax_pa, ax_pa_bn, ax, ax_bn):
            a.clear()
        if i == 0:
            # The raw input spans the data range [-7, 10].
            p_range = the_range = (-7, 10)
        else:
            # tanh outputs lie in (-1, 1); pre-activations are shown on [-4, 4].
            p_range, the_range = (-4, 4), (-1, 1)
        ax_pa.set_title(f'L{i}')
        ax_pa.hist(to_np(pre_ac[i]), bins=10, range=p_range, color='#FF9359', alpha=0.5)
        ax_pa_bn.hist(to_np(pre_ac_bn[i]), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
        ax.hist(to_np(l_in[i]), bins=10, range=the_range, color='#FF9359')
        ax_bn.hist(to_np(l_in_bn[i]), bins=10, range=the_range, color='#74BCFF')
        for a in (ax_pa, ax, ax_pa_bn, ax_bn):
            a.set_yticks(())
            a.set_xticks(())
        ax_pa_bn.set_xticks(p_range)
        ax_bn.set_xticks(the_range)
    axs[0, 0].set_ylabel('PreAct')
    axs[1, 0].set_ylabel('BN PreAct')
    axs[2, 0].set_ylabel('Act')
    axs[3, 0].set_ylabel('BN Act')
    plt.pause(0.01)


def main():
    """Train both nets with live histograms, then plot losses and fits."""
    import matplotlib.pyplot as plt

    torch.manual_seed(1)
    np.random.seed(1)

    # Same RNG draw order as before: training set first, then the test set.
    x, y = make_quadratic_data(N_SAMPLES)
    test_x_np, test_y_np = make_quadratic_data(200)
    train_x = torch.from_numpy(x).float()
    train_y = torch.from_numpy(y).float()
    test_x = torch.from_numpy(test_x_np).float()
    test_y = torch.from_numpy(test_y_np).float()
    train_loader = make_loader(train_x, train_y)

    nets = [Net(batch_normalization=False), Net(batch_normalization=True)]
    print(*nets)
    opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
    loss_func = nn.MSELoss()

    _, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
    plt.ion()  # interactive mode so the histograms refresh during training
    plt.show()

    losses = [[], []]  # test loss at the start of each epoch: [plain, BN]
    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        layer_inputs, pre_acts = evaluate(nets, losses, test_x, test_y, loss_func)
        plot_histogram(*layer_inputs, *pre_acts, axs=axs)
        for b_x, b_y in train_loader:
            for net, opt in zip(nets, opts):
                pred, _, _ = net(b_x)
                loss = loss_func(pred, b_y)
                opt.zero_grad()
                loss.backward()
                opt.step()
    plt.ioff()

    # Test-loss curves.
    plt.figure(2)
    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
    plt.xlabel('epoch')  # one recorded point per epoch, not per optimiser step
    plt.ylabel('test loss')
    plt.ylim((0, 2000))
    plt.legend(loc='best')

    # Eval mode freezes BN to its running mean/var for the final predictions.
    for net in nets:
        net.eval()
    with torch.no_grad():
        preds = [net(test_x)[0] for net in nets]
    plt.figure(3)
    plt.plot(test_x.numpy(), preds[0].numpy(), c='#FF9359', lw=4, label='Original')
    plt.plot(test_x.numpy(), preds[1].numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
    # These points are the held-out test samples, not the training data.
    plt.scatter(test_x.numpy(), test_y.numpy(), c='r', s=50, alpha=0.2, label='test')
    plt.legend(loc='best')
    plt.show()


# DataLoader workers (num_workers=2) re-import this module under the "spawn"
# start method (Windows, macOS), so training must not run at import time.
if __name__ == "__main__":
    main()
阅读全文
0 0
- pytorch Batch Normalization批标准化
- 3.1 Tensorflow: 批标准化(Batch Normalization)
- 17批标准化(Batch Normalization )
- PyTorch笔记9-Batch Normalization
- 深度学习Deep Learning(05):Batch Normalization(BN)批标准化
- [pytorch] 利用batch normalization对Variable进行normalize/instance normalize
- Batch Normalization
- Batch Normalization
- Batch Normalization
- batch normalization
- Batch Normalization
- Batch Normalization
- batch normalization
- Batch Normalization
- Batch Normalization
- Batch Normalization
- Batch Normalization
- Batch Normalization
- Tomcat logs 目录下各日志文件的含义
- swift中的手势
- css应用表格排版td中的文本过长,设置不换行,随内容同行显示样式
- 红黑树及其操作
- C_线性表(ADT)-顺序表的表示和实现
- pytorch Batch Normalization批标准化
- HTMl5内置存储sessionStorage和localStorage
- Android关于Service服务
- Python list之append和extend的区别
- 日本咖啡馆推出AR体验,在三次元邂逅初音未来
- flume之Taildir Source支持变化追加文件的日志收集
- 可编辑div的使用,实现类似textarea功能
- android 6.0和7.0开发中应该注意的问题
- A + B Problem II