Caffe2 - (四) 基于 squeezenet 分类的模型测试
2017-12-29 11:35
369 查看
Caffe2 模型加载与测试
Model Zoo这里以 squeezenet 模型为例,对图片中的 object 分类.
下载训练好的模型:
python -m caffe2.python.models.download -i squeezenet
模型加载:
读取 protobuf 文件:
with open("init_net.pb") as f: init_net = f.read() with open("predict_net.pb") as f: predict_net = f.read()
采用 workspace.Predictor 函数从 protobufs 加载 blobs:
p = workspace.Predictor(init_net, predict_net)
运行 net 并得到结果:
results = p.run([img])
results 是多维数组的形式,存储概率值.
每一行是识别 object 属于某一类的概率.
完整代码
# ------------------------------- # Configuration # ------------------------------- CAFFE2_ROOT = "~/caffe2" CAFFE_MODELS = "~/caffe2/caffe2/python/models" # 均值文件保存到与 model 同一路径 from caffe2.proto import caffe2_pb2 import numpy as np import skimage.io import skimage.transform from matplotlib import pyplot import os from caffe2.python import core, workspace import urllib2 print("Required modules imported.") IMAGE_LOCATION = "https://cdn.pixabay.com/photo/2015/02/10/21/28/flower-631765_1280.jpg" MODEL = 'squeezenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.npy', 227 # codes - these help decypher the output and source from a list from AlexNet's object codes to provide an result like "tabby cat" or "lemon" depending on what's in the picture you submit to the neural network. # The list of output codes for the AlexNet models (also squeezenet) codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes" print "Config set!" # ------------------------------- # Pre-processing image # ------------------------------- def crop_center(img,cropx,cropy): y,x,c = img.shape startx = x//2-(cropx//2) starty = y//2-(cropy//2) return img[starty:starty+cropy,startx:startx+cropx] def rescale(img, input_height, input_width): print("Original image shape:" + str(img.shape) + " and remember it should be in H, W, C!") print("Model's input shape is %dx%d") % (input_height, input_width) aspect = img.shape[1]/float(img.shape[0]) print("Orginal aspect ratio: " + str(aspect)) if(aspect>1): # landscape orientation - wide image res = int(aspect * input_height) imgScaled = skimage.transform.resize(img, (input_width, res)) if(aspect<1): # portrait orientation - tall image res = int(input_width/aspect) imgScaled = skimage.transform.resize(img, (res, input_height)) if(aspect == 1): imgScaled = skimage.transform.resize(img, (input_width, input_height)) pyplot.figure() pyplot.imshow(imgScaled) pyplot.axis('on') pyplot.title('Rescaled image') print("New image shape:" + str(imgScaled.shape) + " in HWC") return imgScaled print "Functions set." # set paths and variables from model choice and prep image CAFFE2_ROOT = os.path.expanduser(CAFFE2_ROOT) CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS) # mean can be 128 or custom based on the model # gives better results to remove the colors found in all of the training images MEAN_FILE = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[3]) if not os.path.exists(MEAN_FILE): mean = 128 else: mean = np.load(MEAN_FILE).mean(1).mean(1) mean = mean[:, np.newaxis, np.newaxis] print "mean was set to: ", mean INPUT_IMAGE_SIZE = MODEL[4] # make sure all of the files are around... if not os.path.exists(CAFFE2_ROOT): print("Houston, you may have a problem.") INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1]) print 'INIT_NET = ', INIT_NET PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2]) print 'PREDICT_NET = ', PREDICT_NET if not os.path.exists(INIT_NET): print(INIT_NET + " not found!") else: print "Found ", INIT_NET, "...Now looking for", PREDICT_NET if not os.path.exists(PREDICT_NET): print "Caffe model file, " + PREDICT_NET + " was not found!" else: print "All needed files found! Loading the model in the next block." # 图片读取与转换 img = skimage.img_as_float(skimage.io.imread(IMAGE_LOCATION)).astype(np.float32) img = rescale(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE) img = crop_center(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE) print "After crop: " , img.shape pyplot.figure() pyplot.imshow(img) pyplot.axis('on') pyplot.title('Cropped') # switch to CHW img = img.swapaxes(1, 2).swapaxes(0, 1) pyplot.figure() for i in range(3): pyplot.subplot(1, 3, i+1) pyplot.imshow(img[i]) pyplot.axis('off') pyplot.title('RGB channel %d' % (i+1)) # switch to BGR img = img[(2, 1, 0), :, :] # remove mean for better results # 减均值 img = img * 255 - mean # add batch size img = img[np.newaxis, :, :, :].astype(np.float32) print "NCHW: ", img.shape # ------------------------------- # 网络初始化 # ------------------------------- with open(INIT_NET) as f: init_net = f.read() with open(PREDICT_NET) as f: predict_net = f.read() p = workspace.Predictor(init_net, predict_net) # ------------------------------- # 网络预测 # ------------------------------- # run the net and return prediction results = p.run([img]) # # turn it into something we can play with and examine which is in a multi-dimensional array results = np.asarray(results) print "results shape: ", results.shape # results shape: (1, 1, 1000, 1, 1) # ------------------------------- # 得到最终结果 # ------------------------------- results = np.delete(results, 1) index = 0 highest = 0 arr = np.empty((0,2), dtype=object) arr[:,0] = int(10) arr[:,1:] = float(10) for i, r in enumerate(results): # imagenet index begins with 1! i=i+1 arr = np.append(arr, np.array([[i,r]]), axis=0) if (r > highest): highest = r index = i print index, " :: ", highest # lookup the code and return the result # top 3 results # sorted(arr, key=lambda x: x[1], reverse=True)[:3] # now we can grab the code list response = urllib2.urlopen(codes) # and lookup our result from the list for line in response: code, result = line.partition(":")[::2] if (code.strip() == str(index)): print result.strip()[1:-2] # output # 985 :: 0.979059 # daisy
Reference
[1] - Loading Pre-Trained Models相关文章推荐
- [MXNet Gluon]基于斯坦福狗的品种分类数据集训练SSD检测模型
- 基于scikit-learn(sklearn)做分类--3.优化--保存模型
- caffe训练模型后,使用模型测试的分类结果全部都是相同的
- 一种测试方向的探讨-基于模型测试调研引发的思考 – 1
- 基于TensorFlow和Keras的垃圾分类模型
- 第003篇:ArcGIS中基于矢量样本点制作分类训练样本和测试样本的方法。
- 一种测试方向的探讨-基于模型测试调研引发的思考 - 1
- 基于caffe特征可视化 以及 用训练好的模型进行分类 2
- 基于对评论进行分类的持续运行模型
- 基于LSTM搭建一个文本情感分类的深度学习模型:准确率往往有95%以上
- Pycaffe-简单测试caffe模型的分类效果和运行速度
- 基于谷歌draco项目的测试---三维模型数据压缩方案
- 使用ASP.NET WEB API构建基于REST风格的服务实战系列教程(一)——使用EF6构建数据库及模型
- 【经典】基于 SSL 的 ASP.NET Web 应用测试自动化
- 基于Tensorflow的英文评论二分类CNN模型
- 一种测试方向的探讨-基于模型测试调研引发的思考 - 2
- 基于BOW模型的图像分类Bag Of Visual Words model for image classification
- 基于am3358的蜂鸣器测试 分类: TI-AM335X 2015-06-10 11:15 253人阅读 评论(0) 收藏
- 转载:一个基于概念的中文文本分类模型
- 使用ASP.NET WEB API构建基于REST风格的服务实战系列教程(一)——使用EF6构建数据库及模型