您的位置:首页 > 其它

[pytorch] 利用batch normalization对Variable进行normalize/instance normalize

2017-04-01 18:14 375 查看
天啦噜!!我发现更新的pytorch已经有instance normalization了!!

不用自己折腾了!!

-2017.5.25

利用 nn.Module 里的 子类 _BatchNorm (在torch.nn.modules.batchnorm中定义),可以实现各种需求的normalize。



在docs里,可以看到,有3种normalization layer,但其实他们都是继承了_BatchNorm这个类的,所以我们看看BatchNorm2d,就可以对其他的方法举一反三啦~

先来看看文档



不清楚没关系,接下来用例子讲解:

创建一个BatchNorm2d的实例的方法如下

import torch.nn as nn
norm = nn.BatchNorm2d(fea_num, affine=False).cuda()


其中,fea_num 是拉出来的维度,就是说按照 fea_num 的维度,其他维度拉成一长条来normalize,fea_num对应input的第1个(维度从0开始计)维度, 所以两者的值应相等。.cuda()是把这个module放到gpu上。

在普通的batch normalize的情况下

input是(batchsize,channel,height,width)=(4,3,5,5)来看,fea_num对应channel。所以channel=0时,求一次mean,var,做一次normalize;channel=1时,求一次。。channel=2时,求一次。。

在训练中,还有两个可以学习的参数gamma & beta,所以在gamma & beta设定为可变参数的情况下,应该这样创建和使用batchnorm layer:

#input is cuda float Variable of batchsize x channel x height x width
#train state
norm = nn.BatchNorm2d(channel).cuda()#默认affine=True
input = norm(input)


注意:

在train之前正确的初始化可变参数

在test/eval 模式下,应该用.eval() 固定住可变参数。

一个input的测试例子:

import numpy as np
from torch.autograd import Variable
BS = 2
C = 3
H = 2
W = 2
input = np.arange(BS*C*H*W)
input.resize(BS, C, H, W)
input = Variable(torch.from_numpy(input).clone().float()).cuda().contiguous()


如果不需要可变参数 gamma & beta,那直接:

#input is cuda float Variable of batchsize x channel x height x width
norm = nn.BatchNorm2d(channel, affine=False).cuda()
input = norm(input)


其他情况的normalize,如instance normalize

input还是(batchsize,channel,height,width)=(4,3,5,5)假设我们想把batchsize这一个维度拉出来,对每一个instance(batchsize=0~3)看做(3,5,5)的3D tensor 求一次normalize,那怎么做呢?其实很简单,把input的第0维和第1维调换一下就好了。

#input is cuda float Variable of batchsize x channel x height x width
instanceNorm = nn.BatchNorm2d(BS, affine=False).cuda()
input = input .transpose(0,1).contiguous()
input = instanceNorm(input)
input = input .transpose(0,1).contiguous()


注意:

affine参数看需求设定,注意事项同普通batch normalize情况

如果没使用.contiguous(),很有可能报错

RuntimeError: Assertion `THCTensor_(isContiguous)(state, t)' failed.  at **/pytorch/torch/lib/THCUNN/generic/BatchNormalization.cu:20


总而言之,记得BatchNorm layer 的 fea_num的取值=input拉出来的那个维度的大小,且该维度应该是input的第1维,如果不是,用resize、transpose、unsqueeze啥的搞到是就好了
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  pytorch batch-norm