您的位置:首页 > Web前端

Caffe solver.net.forward(),solver.test_nets[0].forward() 和 solver.step(1)

2017-12-27 23:35 471 查看

前言

Caffe 代码中的 solver.net.forward() , solver.test_nets[0].forward() 和 solver.step(1) 区别和作用。

正文

三个函数都是将批量大小(batch_size)的图片送到网络, solver.net.forward() 和 solver.test_nets[0].forward() 是将batch_size个图片送到网络中去,只有前向传播(Forward Propagation,BP),solver.net.forward()作用于训练集,solver.test_nets[0].forward() 作用于测试集,一般用于获得测试集的正确率。solver.step(1) 也是将batch_size个图片送到网络中去,不过 solver.step(1) 不仅有FP,而且还有反向传播(Back Propagation,BP)!这样就可以更新整个网络的权值(weights),同时得到该batch的loss。

让我们用代码实例来验证一下上述函数的作用。

下面的例子来自我的博客 使用 Caffe Python 编写 LeNetB1 ,在 B1 中,我定义了网络,我们首先加载这个网路。该网络的训练集batch_size=64,测试集batch_size=100。为获得更高的效率,我在Jupyter Notebook 中实现如下代码。

为了方便,我先给出mnist训练集前192个数字和测试集前200个数字。



代码1:

from pylab import *
import caffe
%matplotlib inline


caffe.set_device(0)
caffe.set_mode_gpu()    #设置GPU,不是GPU环境的使用caffe.set_mode_cpu()

solver = None
solver =    caffe.SGDSolver('C:/Users/Admin512/Desktop/MyStudy/caffe_python/LeNet/mnist/lenet_auto_solver.prototxt')


在运行网络之前,我们先看一下网络的层次:

代码2:

[(k, v.data.shape) for k, v in solver.net.blobs.items()]


输出2:

[('data', (64L, 1L, 28L, 28L)),
('label', (64L,)),
('conv1', (64L, 20L, 24L, 24L)),
('pool1', (64L, 20L, 12L, 12L)),
('conv2', (64L, 50L, 8L, 8L)),
('pool2', (64L, 50L, 4L, 4L)),
('fc1', (64L, 500L)),
('score', (64L, 10L)),
('loss', ())]


看一下第一层“data”中batch图像中的第一个图像是哪个数字?

代码3:

A = solver.net.blobs['data']
print(A.data.shape)
A.data[0,0]
imshow(A.data[0,0], cmap='gray');


输出3:



黑乎乎的一片,并没有显示数字,说明数据没有传入到网络中。

下面我们运行:solver.net.forward() :

代码4:

solver.net.forward()  # train net


再次执行代码3:输出如下:



数据已经传进来了,很明显这是训练集第一个图。

我们再次运行代码4,然后运行代码3,输出如下:



是不是第65(索引为64)个?是!说明执行solver.net.forward() 后,训练集数据按照batch_size=64输入带网络的。

我们再来看 solver.test_nets[0].forward() :

代码5:

solver.test_nets[0].forward()

# 显示test_net传入的数据
B = solver.test_nets[0].blobs['data'].data
imshow(B[0,0], cmap='gray');


输出5:



是测试集第1个图像。

猜想再次运行代码5,应该显示的是第101个图像(索引位100),我们运行一下:



猜想正确!

现在知道 solver.net.forward() 和 solver.test_nets[0].forward() 的作用了,下面我们看看 solver.step(1) 的作用。

首先初始化整个网络,就是再次运行代码1。

代码6:

solver.step(1)
A = solver.net.blobs['data'] print(A.data.shape) A.data[0,0] imshow(A.data[0,0], cmap='gray');


两次运行代码6,发现两次出现的图像和两次运行solver.net.forward() 的效果一样,说明 代码6 执行的是FP+DP。

完。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  caffe python mnist