您的位置:首页 > Web前端

Windows caffe (二) cifar10 demo 训练与测试

2017-03-09 16:30 393 查看

1、数据集的获取

首先需要安装Git和Wget,方法请参考上一篇博客
执行根目录data/cifar10目录下的get_cifar.sh,cifar内容如下:
#!/usr/bin/env sh
# This scripts downloads the CIFAR10 (binary version) data and unzips it.

DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"

echo "Downloading..."

wget --no-check-certificate http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz 
echo "Unzipping..."

tar -xf cifar-10-binary.tar.gz && rm -f cifar-10-binary.tar.gz
mv cifar-10-batches-bin/* . && rm -rf cifar-10-batches-bin

# Creation is split out because leveldb sometimes causes segfault
# and needs to be re-created.

echo "Done."


数据下载完成,在/data/cifar文件夹下多了一些文件。这些文件无法在caffe框架下直接运行,需要转换格式



2、数据格式转换

执行/examples/cifar10目录下的create_cifar10.sh,这里需要做一些修改,我已经标记为黄色底纹
#!/usr/bin/env sh
# This script converts the cifar data into leveldb format.
set -e

LOG_FILE=./LOG.txt

EXAMPLE=./
DATA=../../data/cifar10
DBTYPE=lmdb

echo "Creating $DBTYPE..."

rm -rf $EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/cifar10_test_$DBTYPE

exec 2>>$LOG_FILE

../../Build/x64/Release/convert_cifar_data.exe $DATA $EXAMPLE $DBTYPE

echo "Computing image mean..."

../../Build/x64/Release/compute_image_mean.exe -backend=$DBTYPE \
$EXAMPLE/cifar10_train_$DBTYPE $EXAMPLE/mean.binaryproto

echo "Done."


其中LOG_FILE=./LOG.TXT和exec 2>>$LOG_FILE是为了打印错误的语句到日志文件,方便我们检查错误
执行后的结果为:





3.神经网络的训练

修改exampe/cifar10文件夹下train_quick.sh,修改后的内容如下
#!/usr/bin/env sh
set -e

CAFFE_ROOT=D:/Caffe/Caffe_BVLC
TOOLS=$CAFFE_ROOT/Build/x64/Release

exec 2>>log.txt

$TOOLS/caffe train \
--solver=$CAFFE_ROOT/examples/cifar10/cifar10_quick_solver.prototxt $@

# reduce learning rate by factor of 10 after 8 epochs
$TOOLS/caffe train \
--solver=$CAFFE_ROOT/examples/cifar10/cifar10_quick_solver_lr1.prototxt \
--snapshot=$CAFFE_ROOT/examples/cifar10/cifar10_quick_iter_4000.solverstate.h5 $@



修改 cifar10_quick_solver.prototxt、cifar10_quick_solver_lr1.prototxt,GPU or CPU根据自己的情况修改

①cifar10_quick_solver.prototxt

# reduce the learning rate after 8 epochs (4000 iters) by a factor of 10

# The train/test net protocol buffer definition

net: "D:/Caffe/Caffe_BVLC/examples/cifar10/cifar10_quick_train_test.prototxt"

# test_iter specifies how many forward passes the test should carry out.

# In the case of MNIST, we have test batch size 100 and 100 test iterations,

# covering the full 10,000 testing images.

test_iter: 100

# Carry out testing every 500 training iterations.

test_interval: 500

# The base learning rate, momentum and the weight decay of the network.

base_lr: 0.0001

momentum: 0.9

weight_decay: 0.004

# The learning rate policy

lr_policy: "fixed"

# Display every 100 iterations

display: 100

# The maximum number of iterations

max_iter: 5000

# snapshot intermediate results

snapshot: 5000

snapshot_format: HDF5

snapshot_prefix: "D:/Caffe/Caffe_BVLC/examples/cifar10/cifar10_quick"

# solver mode: CPU or GPU

solver_mode: CPU

②cifar10_quick_solver_lr1.prototxt

# reduce the learning rate after 8 epochs (4000 iters) by a factor of 10

# The train/test net protocol buffer definition

net: "D:/Caffe/Caffe_BVLC/examples/cifar10/cifar10_quick_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.0001
momentum: 0.9
weight_decay: 0.004
# The learning rate policy
lr_policy: "fixed"
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 5000
# snapshot intermediate results
snapshot: 5000
snapshot_format: HDF5
snapshot_prefix: "D:/Caffe/Caffe_BVLC/examples/cifar10/cifar10_quick"
# solver mode: CPU or GPU
solver_mode: CPU

调整好后,点击train_quick.sh进行训练,由于使用git bash没有中间结果在屏幕上显示,我将文档写到了log.txt   (exec 2>>log.txt),整个模型训练下来大约20多分钟
训练后的结果如下:精度约达到74.83%



4、模型测试

因为cifar模型中没有给我们提供测试模板,需要自己创建一个,在example/cifar10目录下新建一个文本文件,重命名为test_cifar10_quick.bat
内容如下:
..\..\Build\x64\Release\caffe.exe test -model=.\cifar10_quick_train_test.prototxt -weights=.\cifar10_quick_iter_5000.caffemodel.h5 -iterations=100

pause
这里需要注意修改cifar10_quick_train_test.prototxt的路径
执行完成的结果如下:



至此,cifar demo的训练和测试完成
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: