您的位置:首页 > Web前端

使用Caffe基于cifar10进行物体识别

2017-06-10 18:31 405 查看
在 http://blog.csdn.net/fengbingchun/article/details/72953284 中已对cifar10进行了训练(train),这里使用训练得到的模型(model)对图像进行识别。cifar10数据集共包括10类,按照0到9的顺序依次为airplane(飞机)、automobile(轿车)、bird(鸟)、cat(猫)、deer(鹿)、dog(狗)、frog(青蛙)、horse(马)、ship(船)、truck(卡车)。
在识别前需要对原有的cifar10_quick_train_test.prototxt文件进行调整:输入改为MemoryData层(batch_size为1,以便在代码中逐幅送入图像),输出改为Softmax层以得到各类别的概率。调整后的内容如下:
# Inference variant of cifar10_quick_train_test.prototxt. Presumably this is the file the
# C++ test code loads as cifar10_quick_train_test_.prototxt (TODO confirm).
# Input: a MemoryData layer filled from C++ via MemoryDataLayer::Reset.
# Output: "prob", a Softmax over the 10 CIFAR-10 classes.
name: "CIFAR10_quick"
# Input layer: one 3 x 32 x 32 float image per forward pass (batch_size 1).
# The C++ caller supplies planar RGB data with the mean image already subtracted.
# The "label" top is not consumed by any later layer, so it also becomes a net output blob.
layer {
name: "data"
type: "MemoryData"
top: "data"
top: "label"
memory_data_param {
batch_size: 1
channels: 3
height: 32
width: 32
}
}
# 32 filters 5x5, pad 2, stride 1: spatial size stays 32x32.
# lr_mult / weight_filler / bias_filler only matter for training; at inference the
# weights are loaded from the trained .caffemodel.h5 file.
layer {
name: "conv1"
type: "Convolution"
bottom: "data"
top: "conv1"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
convolution_param {
num_output: 32
pad: 2
kernel_size: 5
stride: 1
weight_filler {
type: "gaussian"
std: 0.0001
}
bias_filler {
type: "constant"
}
}
}
# 3x3 MAX pooling, stride 2: 32x32 -> 16x16 (Caffe rounds pooled sizes up).
layer {
name: "pool1"
type: "Pooling"
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 3
stride: 2
}
}
# In-place ReLU applied after max pooling. ReLU is monotonic, so
# ReLU(max(x)) == max(ReLU(x)) and this order gives the same result as ReLU-then-pool.
layer {
name: "relu1"
type: "ReLU"
bottom: "pool1"
top: "pool1"
}
# 32 filters 5x5, pad 2, stride 1: spatial size stays 16x16.
layer {
name: "conv2"
type: "Convolution"
bottom: "pool1"
top: "conv2"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
convolution_param {
num_output: 32
pad: 2
kernel_size: 5
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "relu2"
type: "ReLU"
bottom: "conv2"
top: "conv2"
}
# 3x3 AVE pooling, stride 2: 16x16 -> 8x8.
layer {
name: "pool2"
type: "Pooling"
bottom: "conv2"
top: "pool2"
pooling_param {
pool: AVE
kernel_size: 3
stride: 2
}
}
# 64 filters 5x5, pad 2, stride 1: spatial size stays 8x8.
layer {
name: "conv3"
type: "Convolution"
bottom: "pool2"
top: "conv3"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
convolution_param {
num_output: 64
pad: 2
kernel_size: 5
stride: 1
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "relu3"
type: "ReLU"
bottom: "conv3"
top: "conv3"
}
# 3x3 AVE pooling, stride 2: 8x8 -> 4x4, i.e. 64 * 4 * 4 = 1024 inputs to ip1.
layer {
name: "pool3"
type: "Pooling"
bottom: "conv3"
top: "pool3"
pooling_param {
pool: AVE
kernel_size: 3
stride: 2
}
}
# Fully connected layer with 64 outputs. There is no ReLU between ip1 and ip2 in this net.
layer {
name: "ip1"
type: "InnerProduct"
bottom: "pool3"
top: "ip1"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
inner_product_param {
num_output: 64
weight_filler {
type: "gaussian"
std: 0.1
}
bias_filler {
type: "constant"
}
}
}
# One score per CIFAR-10 class (0 airplane ... 9 truck). The C++ code locates the
# prediction among the net outputs by this element count of 10.
layer {
name: "ip2"
type: "InnerProduct"
bottom: "ip1"
top: "ip2"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
inner_product_param {
num_output: 10
weight_filler {
type: "gaussian"
std: 0.1
}
bias_filler {
type: "constant"
}
}
}
# Softmax converts the ip2 scores into class probabilities.
layer {
name: "prob"
type: "Softmax"
bottom: "ip2"
top: "prob"
} 可视化结果如下图(https://ethereon.github.io/netscope/quickstart.html):


从网上找了10幅较标准的图像(每类一幅,按类别编号命名为0.jpg~9.jpg,测试代码据此确定每幅图像的真实类别),如下:


测试代码如下:
#include "funset.hpp"
#include "common.hpp"

int cifar10_predict()
{
#ifdef CPU_ONLY
caffe::Caffe::set_mode(caffe::Caffe::CPU);
#else
caffe::Caffe::set_mode(caffe::Caffe::GPU);
#endif

const std::string param_file{ "E:/GitCode/Caffe_Test/test_data/model/cifar10/cifar10_quick_train_test_.prototxt" };
const std::string trained_filename{ "E:/GitCode/Caffe_Test/test_data/model/cifar10/cifar10_quick_iter_4000.caffemodel.h5" };
const std::string image_path{ "E:/GitCode/Caffe_Test/test_data/images/object_recognition/" };
const std::string mean_file{"E:/GitCode/Caffe_Test/test_data/model/cifar10/mean.binaryproto"};

caffe::Net<float> caffe_net(param_file, caffe::TEST);
caffe_net.CopyTrainedLayersFromHDF5(trained_filename);

const boost::shared_ptr<caffe::Blob<float> > blob_by_name = caffe_net.blob_by_name("data");
int image_channel = blob_by_name->channels();
int image_height = blob_by_name->height();
int image_width = blob_by_name->width();

int num_outputs = caffe_net.num_outputs();
const std::vector<caffe::Blob<float>*>& output_blobs = caffe_net.output_blobs();
int require_blob_index{ -1 };
const int digit_category_num{ 10 };
for (int i = 0; i < output_blobs.size(); ++i) {
if (output_blobs[i]->count() == digit_category_num)
require_blob_index = i;
}
if (require_blob_index == -1) {
fprintf(stderr, "ouput blob don't match\n");
return -1;
}

std::vector<int> target{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
std::vector<int> result;

// read mean data
caffe::BlobProto image_mean; // storage order: rr..rrgg..ggbb..bb
if (!caffe::ReadProtoFromBinaryFile(mean_file, &image_mean)) {
fprintf(stderr, "parse mean file fail\n");
return -1;
}

if (image_channel != image_mean.channels() || image_height != image_mean.height() || image_width != image_mean.width() ||
image_channel != 3) {
fprintf(stderr, "their dimension dismatch\n");
return -1;
}

cv::Mat mat_mean(image_height, image_width, CV_32FC3, const_cast<float*>(image_mean.data().data()));

for (auto num : target) {
std::string str = std::to_string(num);
str += ".jpg";
str = image_path + str;

cv::Mat mat = cv::imread(str.c_str(), 1);
if (!mat.data) {
fprintf(stderr, "load image error: %s\n", str.c_str());
return -1;
}

if (image_channel == 1)
cv::cvtColor(mat, mat, CV_BGR2GRAY);
else if (image_channel == 4)
cv::cvtColor(mat, mat, CV_BGR2BGRA);

cv::resize(mat, mat, cv::Size(image_width, image_height));
mat.convertTo(mat, CV_32FC3);

// Note: need to subtract mean
std::vector<cv::Mat> mat_tmp2; //b,g,r
cv::split(mat, mat_tmp2);
cv::Mat mat_tmp3(image_height, image_width, CV_32FC3);
float* p = (float*)mat_tmp3.data;
memcpy(p, mat_tmp2[2].data, image_height * image_width * sizeof(float));
memcpy(p + image_height * image_width, mat_tmp2[1].data, image_height * image_width * sizeof(float));
memcpy(p + image_height * image_width * 2, mat_tmp2[0].data, image_height * image_width * sizeof(float));
cv::subtract(mat_tmp3, mat_mean, mat_tmp3);

boost::shared_ptr<caffe::MemoryDataLayer<float> > memory_data_layer =
boost::static_pointer_cast<caffe::MemoryDataLayer<float>>(caffe_net.layer_by_name("data"));
float dummy_label[1] {0};
memory_data_layer->Reset((float*)(mat_tmp3.data), dummy_label, 1); // rr..rrgg..ggbb..bb

float loss{ 0.0 };
const std::vector<caffe::Blob<float>*>& results = caffe_net.ForwardPrefilled(&loss); // Net forward
const float* output = results[require_blob_index]->cpu_data();

float tmp{ -1 };
int pos{ -1 };

fprintf(stderr, "actual digit is: %d\n", target[num]);
for (int j = 0; j < 10; j++) {
printf("Probability to be Number %d is: %.3f\n", j, output[j]);
if (tmp < output[j]) {
pos = j;
tmp = output[j];
}
}

result.push_back(pos);
}

for (auto i = 0; i < 10; i++)
fprintf(stderr, "actual digit is: %d, result digit is: %d\n", target[i], result[i]);

fprintf(stderr, "predict finish\n");

return 0;
} 测试结果如下:

其中鹿和青蛙识别错误。
GitHub:https://github.com/fengbingchun/Caffe_Test
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: