您的位置:首页 > 其它

tensorflow官方教程:运用模型对类别进行预测

2018-01-16 13:54 495 查看

tensorflow官方教程:运用模型对类别进行预测

本文主要包含如下内容:

tensorflow官方教程运用模型对类别进行预测
python版本

C代码

  本教程将会教你如何使用Inception-v3。你将学会如何用Python或者C++把图像分为1000个类别.

python版本

  本段代码为tensorflow的教程代码.在开始运用模型Inception-v3对图像类别进行预测之前, 需要下载tensorflow/model.

  该python代码位于:models/tutorials/image/imagenet/classify_image.py中,执行代码即可进行预测:

cd models/tutorials/image/imagenet
python classify_image.py

# 测试结果如下:
giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.89107)
indri, indris, Indri indri, Indri brevicaudatus (score = 0.00779)
lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00296)
custard apple (score = 0.00147)
earthstar (score = 0.00117)


  


  其中, classify_image.py的核心代码为加载模型/前向传播预测结果

def create_graph():
"""Creates a graph from saved GraphDef file and returns a saver."""       # 加载模型
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')

def run_inference_on_image(image):      # 前向传播网络,预测图像类别
"""Runs inference on an image.

Args:
image: Image file name.

Returns:
Nothing
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()       # 读入图像数据

# Creates graph from saved GraphDef.      加载模型
create_graph()

with tf.Session() as sess:
# Some useful tensors:
# 'softmax:0': A tensor containing the normalized prediction across
#   1000 labels.
# 'pool_3:0': A tensor containing the next-to-last layer containing 2048
#   float description of the image.
# 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
#   encoding of the image.
# Runs the softmax tensor by feeding the image_data as input to the graph.
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')         # 捕获输出
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})       # 前向传播
predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
node_lookup = NodeLookup()      # 获得ID对应类别

top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
for node_id in top_k:
human_string = node_lookup.id_to_string(node_id)
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string, score))        # 打印结果

# 代码思路:首先读入输入图像,加载测试模型,然后前向传播捕获对应输出,并打印对应结果。


C++代码

  对应的C++代码位于
/tensorflow/tensorflow/examples/label_image/main.cc
  参考网站

  要完成对图像的预测,首先需要下载模型,将网址复制到网站上下载网络模型, 然后将其解压到指定目录:

https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz tar -zxvf inception_v3_2016_08_28_frozen.pb.tar.gz -C tensorflow/examples/label_image/data


  接下来,运用tensorflow源码进行编译,在终端中编译例子步,生成并执行二进制可执行文件:

bazel build tensorflow/examples/label_image/...
bazel-bin/tensorflow/examples/label_image/label_image


  它使用了框架自带的示例图片,输出的结果大致是这样:

I tensorflow/examples/label_image/main.cc:250] military uniform (653): 0.834306
I tensorflow/examples/label_image/main.cc:250] mortarboard (668): 0.0218695
I tensorflow/examples/label_image/main.cc:250] academic gown (401): 0.0103581
I tensorflow/examples/label_image/main.cc:250] pickelhaube (716): 0.00800814
I tensorflow/examples/label_image/main.cc:250] bulletproof vest (466): 0.00535085


  这里,我们使用的默认图像是 Admiral Grace Hopper,网络模型正确地识别出她穿着一套军服,分数高达0.8。

  
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐