
[opencv] Loading Caffe training results with the DNN module of OpenCV 3.3 ([Caffe study, part 5])

2017-11-04 09:40
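Continued from the previous post: http://blog.csdn.net/qq_15947787/article/details/78441232

After training a Caffe model with GoogLeNet, the trained result has to be loaded in OpenCV. Conveniently, OpenCV 3.3 provides the dnn module for exactly this.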

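1. Introduction to the OpenCV 3.3 DNN module

In the OpenCV 3.3 release, the DNN module was moved from the extra modules (opencv_contrib) into the main OpenCV distribution. The DNN module originally came from Tiny-dnn and could load pre-trained Caffe models; OpenCV has since extended it to load models trained and exported by the major deep learning frameworks, most commonly: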

Caffe

TensorFlow

Torch/PyTorch 

The OpenCV DNN module already supports, and has been tested with, the following common networks:

AlexNet
GoogLeNet v1 (also referred to as Inception-5h)

ResNet-34/50/...

SqueezeNet v1.1

VGG-based FCN (semantic segmentation network)

ENet (lightweight semantic segmentation network)

VGG-based SSD (object detection network)

MobileNet-based SSD (light-weight object detection network)

The DNN samples that ship with OpenCV are located in:

opencv3.3\sources\samples\dnn



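Functions and frameworks

Some of the functions we will use are listed below (these are the Python names; the C++ equivalents live in the cv::dnn namespace, e.g. cv::dnn::blobFromImage).

Converting images (already loaded, e.g. with imread) into the 4-D input blob used by dnn: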

cv2.dnn.blobFromImage

cv2.dnn.blobFromImages
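blobFromImage does more than wrap the image: it can resize it to the network's input size, subtracts the per-channel mean, multiplies by the scale factor, optionally swaps R and B, and rearranges the interleaved pixels into a 4-D N x C x H x W float blob. The sketch below reproduces only the mean/scale/layout part for a single image that is already the right size, using plain std::vector buffers instead of cv::Mat. The function name toBlobNCHW is hypothetical, and resizing, cropping and swapRB are deliberately left out. It subtracts the mean first and then scales; with a scale of 1.0, as in the sample below, the order makes no difference, but for other scales check the exact behaviour of your OpenCV version.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified illustration of blobFromImage for ONE image that
// already has the network's input size (no resize/crop/swapRB here).
// Input : interleaved 8-bit BGR pixels in HWC order (like a CV_8UC3 cv::Mat).
// Output: a 1 x 3 x H x W float blob in NCHW order, computed as
//         (pixel - mean[channel]) * scale.
std::vector<float> toBlobNCHW(const std::vector<std::uint8_t>& bgr,
                              int height, int width,
                              const std::array<float, 3>& mean,
                              float scale)
{
    const std::size_t plane = static_cast<std::size_t>(height) * width;
    assert(bgr.size() == plane * 3);   // expect exactly 3 channels per pixel
    std::vector<float> blob(3 * plane);
    for (int y = 0; y < height; ++y)
        for (int x = 0; x < width; ++x)
            for (int c = 0; c < 3; ++c) {
                // index in the interleaved (HWC) source buffer
                const std::size_t hwc = (static_cast<std::size_t>(y) * width + x) * 3 + c;
                // index in the planar (CHW) destination: one full plane per channel
                const std::size_t chw = c * plane + static_cast<std::size_t>(y) * width + x;
                blob[chw] = (bgr[hwc] - mean[c]) * scale;
            }
    return blob;
}
```

For a 224x224 BGR image this yields 3 * 224 * 224 floats, laid out as all blue values, then all green, then all red, which is the layout the Caffe "data" layer expects.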

Importing models from the various frameworks with the "create" methods:

cv2.dnn.createCaffeImporter

cv2.dnn.createTensorFlowImporter

cv2.dnn.createTorchImporter

Loading serialized models directly from disk with the "read" methods:

cv2.dnn.readNetFromCaffe

cv2.dnn.readNetFromTensorFlow

cv2.dnn.readNetFromTorch

cv2.dnn.readTorchBlob

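Once the model has been loaded from disk, the input blob is passed in with .setInput and the .forward method propagates it through the network to obtain the classification result.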
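For GoogLeNet, forward("prob") returns the output of the final softmax layer, i.e. one probability per class, and getMaxClass in the sample keeps only the largest one via minMaxLoc. If you also want the runners-up (a top-5 list, for instance), a partial sort over the class indices is enough. The sketch assumes the probabilities have already been copied out of the output Mat into a std::vector<float>; topK is a hypothetical helper, not an OpenCV function.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical helper: indices of the k largest probabilities, best first.
std::vector<int> topK(const std::vector<float>& prob, int k)
{
    assert(k >= 0);
    std::vector<int> idx(prob.size());
    std::iota(idx.begin(), idx.end(), 0);                 // 0, 1, ..., N-1
    k = std::min<int>(k, static_cast<int>(idx.size()));   // clamp k to N
    // Only the first k positions need to be ordered.
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&prob](int a, int b) { return prob[a] > prob[b]; });
    idx.resize(k);
    return idx;
}
```

With k = 1 this gives the same class as getMaxClass (up to ties); the index can then be used with classNames.at(...) exactly as in the sample.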

2. Calling a Caffe model with the OpenCV 3.3 dnn module

Taking D:\opencv3.3\sources\samples\dnn\caffe_googlenet.cpp as an example:

/**M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/core/utils/trace.hpp>
using namespace cv;
using namespace cv::dnn;

#include <fstream>
#include <iostream>
#include <cstdlib>
using namespace std;

/* Find best class for the blob (i. e. class with maximal probability) */
static void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
{
    Mat probMat = probBlob.reshape(1, 1); //reshape the blob to a 1 x N matrix (N = number of classes, 1000 for ImageNet)
    Point classNumber;

    minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
    *classId = classNumber.x;
}

static std::vector<String> readClassNames(const char *filename = "synset_words.txt")
{
    std::vector<String> classNames;

    std::ifstream fp(filename);
    if (!fp.is_open())
    {
        std::cerr << "File with classes labels not found: " << filename << std::endl;
        exit(-1);
    }

    std::string name;
    // Loop on getline itself instead of eof(), so a failed read is never processed
    while (std::getline(fp, name))
    {
        if (name.length())
            classNames.push_back( name.substr(name.find(' ')+1) );  // drop the leading "index " part
    }

    fp.close();
    return classNames;
}

int main(int argc, char **argv)
{
    CV_TRACE_FUNCTION();

    String modelTxt = "bvlc_googlenet.prototxt";
    String modelBin = "bvlc_googlenet.caffemodel";
    String imageFile = (argc > 1) ? argv[1] : "space_shuttle.jpg";

    Net net;
    try {
        //! [Read and initialize network]
        net = dnn::readNetFromCaffe(modelTxt, modelBin);
        //! [Read and initialize network]
    }
    catch (cv::Exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
    }

    //! [Check that network was read successfully]
    // Checked outside the catch block, so an empty network is also caught when no exception was thrown
    if (net.empty())
    {
        std::cerr << "Can't load network by using the following files: " << std::endl;
        std::cerr << "prototxt:   " << modelTxt << std::endl;
        std::cerr << "caffemodel: " << modelBin << std::endl;
        std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
        std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
        exit(-1);
    }
    //! [Check that network was read successfully]

    //! [Prepare blob]
    Mat img = imread(imageFile);
    if (img.empty())
    {
        std::cerr << "Can't read image from the file: " << imageFile << std::endl;
        exit(-1);
    }

    //GoogLeNet accepts only 224x224 BGR-images
    Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224),
                                  Scalar(104, 117, 123), false);   //Convert Mat to batch of images
    //! [Prepare blob]

    Mat prob;
    cv::TickMeter t;
    for (int i = 0; i < 10; i++)
    {
        CV_TRACE_REGION("forward");
        //! [Set input blob]
        net.setInput(inputBlob, "data");        //set the network input
        //! [Set input blob]
        t.start();
        //! [Make forward pass]
        prob = net.forward("prob");             //compute output
        //! [Make forward pass]
        t.stop();
    }

    //! [Gather output]
    int classId;
    double classProb;
    getMaxClass(prob, &classId, &classProb);    //find the best class
    //! [Gather output]

    //! [Print results]
    std::vector<String> classNames = readClassNames();
    std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
    std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
    //! [Print results]
    std::cout << "Time: " << (double)t.getTimeMilli() / t.getCounter() << " ms (average from " << t.getCounter() << " iterations)" << std::endl;

    return 0;
} //main


The places that need to be changed for your own model:

Location 1

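Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224),
                              Scalar(104, 117, 123), false);
Scalar(104, 117, 123) is the per-channel (B, G, R) mean of the ImageNet training images; replace it with the mean computed by make_imagenet_mean.sh for your own training set. The script writes a per-pixel mean image (a .binaryproto file), and the compute_image_mean tool it calls also prints the per-channel mean values in its log, which are the three numbers needed here.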
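make_imagenet_mean.sh produces a mean image with one value per channel and pixel, while blobFromImage only accepts one value per channel. Reducing the mean image to three numbers is a plain average over each channel plane. The sketch below assumes the mean image has already been read from the .binaryproto into a float vector in Caffe's C x H x W layout (reading the protobuf file needs Caffe and is not shown), and that the training images were stored in BGR order, as Caffe's convert_imageset typically does; channelMeans is a hypothetical name.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: average a 3 x H x W mean image (CHW layout, channel
// order as stored by Caffe, usually B, G, R) into three per-channel values
// for the Scalar(...) mean argument of blobFromImage.
std::array<double, 3> channelMeans(const std::vector<float>& meanImage,
                                   int height, int width)
{
    const std::size_t plane = static_cast<std::size_t>(height) * width;
    assert(plane > 0 && meanImage.size() == 3 * plane);
    std::array<double, 3> result{};
    for (int c = 0; c < 3; ++c) {
        double sum = 0.0;                        // accumulate in double for accuracy
        for (std::size_t i = 0; i < plane; ++i)
            sum += meanImage[c * plane + i];     // channel c occupies one contiguous plane
        result[c] = sum / plane;
    }
    return result;
}
```

The three results go into the sample in the same order, e.g. Scalar(m[0], m[1], m[2]) with swapRB left as false.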

—————————————————————————————————————————

Location 2

static std::vector<String> readClassNames(const char *filename = "synset_words.txt")
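The name of the label file; change it to your own label file.

Its format is one class per line: the class index, a space, then the class name, in the same order as the class IDs used during training: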

0 Pass
1 Excellent
2 Good
3 Fail
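readClassNames keeps everything after the first space of each non-empty line, and main looks the name up with classNames.at(classId), so the k-th non-empty line must describe class k. A stricter reader can check this while parsing. The sketch below is a hypothetical parseLabels helper (not part of the OpenCV sample) that reads "index name" lines from any std::istream, strips a trailing '\r' left by Windows line endings, and throws if an index is out of order.

```cpp
#include <cassert>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: parse "index name" label lines and verify that the
// index on line k equals k, so names[classId] is the right label.
std::vector<std::string> parseLabels(std::istream& in)
{
    std::vector<std::string> names;
    std::string line;
    while (std::getline(in, line)) {
        if (!line.empty() && line.back() == '\r')
            line.pop_back();                     // tolerate CRLF line endings
        if (line.empty())
            continue;                            // skip blank lines, like readClassNames
        std::istringstream fields(line);
        int index = -1;
        if (!(fields >> index) || index != static_cast<int>(names.size()))
            throw std::runtime_error("label index missing or out of order: " + line);
        std::string name;
        std::getline(fields >> std::ws, name);   // the rest of the line is the name
        names.push_back(name);
    }
    return names;
}
```

Usage: std::ifstream f("labels.txt"); auto names = parseLabels(f); where labels.txt is a placeholder for your own label file.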

—————————————————————————————————————————

Location 3

String modelTxt = "bvlc_googlenet.prototxt";
String modelBin = "bvlc_googlenet.caffemodel";
String imageFile = (argc > 1) ? argv[1] : "space_shuttle.jpg";
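The paths of the GoogLeNet model files (prototxt and caffemodel) and of the input image. For your own model, point modelTxt at the deploy prototxt (not train_val.prototxt) and modelBin at the .caffemodel produced by training. The sample also assumes the input layer is named "data" and the output layer "prob", so check that your deploy prototxt uses the same names.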
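Because the sample hard-codes the two model file names and only reads the image path from argv[1], switching between models means recompiling. One option is to take all three paths from the command line and fall back to the sample's defaults. The sketch below assumes the argument order program [image] [prototxt] [caffemodel]; ModelPaths and pathsFromArgs are hypothetical names, not part of the OpenCV sample.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical container for the three paths the sample needs.
struct ModelPaths {
    std::string image;
    std::string prototxt;
    std::string caffemodel;
};

// args[0] is the program name, as with argv; missing arguments keep the
// defaults used by caffe_googlenet.cpp.
ModelPaths pathsFromArgs(const std::vector<std::string>& args)
{
    ModelPaths p{"space_shuttle.jpg", "bvlc_googlenet.prototxt", "bvlc_googlenet.caffemodel"};
    if (args.size() > 1) p.image = args[1];
    if (args.size() > 2) p.prototxt = args[2];
    if (args.size() > 3) p.caffemodel = args[3];
    return p;
}
```

In main this could be called as pathsFromArgs(std::vector<std::string>(argv, argv + argc)), with the results passed to readNetFromCaffe and imread.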

Result:
