CS231n 2016 通关 第五、六章 Batch Normalization 作业
2016-06-07 21:40
357 查看
BN层在实际中应用广泛。
上一次总结了使得训练变得简单的方法,比如SGD+momentum RMSProp Adam,BN是另外的方法。
cell 1 依旧是初始化设置
cell 2 读取cifar-10数据
cell 3 BN的前传
相应的核心代码:
running_mean running_var 是在test时使用的,test时不再另外计算均值和方差。
test 时的前传核心代码:
cell 5 BN反向传播
通过反向传播,计算beta gamma等参数。
核心代码:
cell 9 BN与其他层结合
形成的结构: {affine - [batch norm] - relu - [dropout]} x (L - 1) - affine - softmax
原理依旧。
之后是对cell 9 的模型,对cifar-10数据训练。
值得注意的是:
使用BN后,正则项与dropout层的需求降低。可以使用较高的学习率加快模型收敛。
附:通关CS231n企鹅群:578975100 validation:DL-CS231n
上一次总结了使得训练变得简单的方法,比如SGD+momentum RMSProp Adam,BN是另外的方法。
cell 1 依旧是初始化设置
cell 2 读取cifar-10数据
cell 3 BN的前传
# Check the training-time forward pass by checking means and variances # of features both before and after batch normalization # Simulate the forward pass for a two-layer network N, D1, D2, D3 = 200, 50, 60, 3 X = np.random.randn(N, D1) W1 = np.random.randn(D1, D2) W2 = np.random.randn(D2, D3) a = np.maximum(0, X.dot(W1)).dot(W2) print 'Before batch normalization:' print ' means: ', a.mean(axis=0) print ' stds: ', a.std(axis=0) # Means should be close to zero and stds close to one print 'After batch normalization (gamma=1, beta=0)' a_norm, _ = batchnorm_forward(a, np.ones(D3), np.zeros(D3), {'mode': 'train'}) print ' mean: ', a_norm.mean(axis=0) print ' std: ', a_norm.std(axis=0) # Now means should be close to beta and stds close to gamma gamma = np.asarray([1.0, 2.0, 3.0]) beta = np.asarray([11.0, 12.0, 13.0]) a_norm, _ = batchnorm_forward(a, gamma, beta, {'mode': 'train'}) print 'After batch normalization (nontrivial gamma, beta)' print ' means: ', a_norm.mean(axis=0) print ' stds: ', a_norm.std(axis=0)
相应的核心代码:
buf_mean = np.mean(x, axis=0) buf_var = np.var(x, axis=0) x_hat = x - buf_mean x_hat = x_hat / (np.sqrt(buf_var + eps)) out = gamma * x_hat + beta #running_mean = momentum * running_mean + (1 - momentum) * sample_mean #running_var = momentum * running_var + (1 - momentum) * sample_var running_mean = momentum * running_mean + (1- momentum) * buf_mean running_var = momentum * running_var + (1 - momentum) * buf_var
running_mean running_var 是在test时使用的,test时不再另外计算均值和方差。
test 时的前传核心代码:
x_hat = x - running_mean x_hat = x_hat / (np.sqrt(running_var + eps)) out = gamma * x_hat + beta
cell 5 BN反向传播
通过反向传播,计算beta gamma等参数。
核心代码:
dx_hat = dout * cache['gamma'] dgamma = np.sum(dout * cache['x_hat'], axis=0) dbeta = np.sum(dout, axis=0) #x_hat = x - buf_mean #x_hat = x_hat / (np.sqrt(buf_var + eps)) t1 = cache['x'] - cache['mean'] t2 = (-0.5)*((cache['var'] + cache['eps'])**(-1.5)) t1 = t1 * t2 d_var = np.sum(dx_hat * t1, axis=0) tmean1 = (-1)*((cache['var'] + cache['eps'])**(-0.5)) d_mean = np.sum(dx_hat * tmean1, axis=0) tmean1 = (-1)*tmean1 tx1 = dx_hat * tmean1 tx2 = d_mean * (1.0 / float(N)) tx3 = d_var * (2 * (cache['x'] - cache['mean']) / N) dx = tx1 + tx2 + tx3
cell 9 BN与其他层结合
形成的结构: {affine - [batch norm] - relu - [dropout]} x (L - 1) - affine - softmax
原理依旧。
之后是对cell 9 的模型,对cifar-10数据训练。
值得注意的是:
使用BN后,正则项与dropout层的需求降低。可以使用较高的学习率加快模型收敛。
附:通关CS231n企鹅群:578975100 validation:DL-CS231n
相关文章推荐
- emoji
- Leetcode Binary Tree Upside Down
- C/C++ 中缀表达式转换成后缀表达式并求值
- 多线程和多进程的区别(小结)
- Maven内置变量说明:
- telnet、ssh配置信息
- jquery与php交互之GET、 POST
- Struts2源码分析——StrutsPrepareAndExecuteFilter
- javaweb学习总结(四十五)——监听器(Listener)学习二
- 旅游队参加省赛的总结
- gdb调试多进程和多线程程序
- java读取文件 每行首字丢失问题
- 凹数科技笔试
- HDU1002(高精度计算)
- 运维工程师到底在作什么?从何学起,掌握哪些知识?
- 内联函数和宏定义的区别
- java中util.Date和数据库中datetime的操作!
- 汉诺塔问题小结
- 第二次冲刺第十天
- iOS 开发技巧收藏贴 链接整理