您的位置:首页 > 大数据 > 人工智能

Retrain a tensorflow model based on Inception v3

2017-06-21 10:43 381 查看
本文在谷歌2015_CVPR Inception v3模型的基础上,结合花朵识别的具体问题重新训练该模型,以获取自己需要的tensorflow模型。

重新训练Inception v3实质是在原有模型输出层后,新加了一个输出层作为最终的输出层,我们只训练这个新加的输出层。这里使用了迁移学习的概念。

Transfer learning, which means we are starting with a model that has been already trained on another problem. We will then be retraining it on a similar problem. Deep learning from scratch can take days, but transfer learning can be done in short order.

准备

本节主要给出了训练tensorflow模型的一些前提条件。

硬件环境

Ubuntu 16.04

安装tensorflow

参考tensorflow Github进行安装。

安装git

$ sudo apt-get update
$ sudo apt-get install git


准备训练样本

$ cd ~
$ mkdir tf_files
$ cd tf_files
$ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz $ tar xzf flower_photos.tgz
$ ls flower_photos


flower_photos.tgz有218MB。

[可选操作]

$ cd ~/tf_files
$ ls flower_photos/roses | wc -l
$ rm flower_photos/*/[3-9]*  # 删除70%的样本数量,减少训练时间。
$ ls flower_photos/roses | wc -l


开始训练

下载retrain脚本

该脚本会自动下载google Inception v3 模型相关文件。

$ cd ~/tf_files
$ curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py[/code] 

启动tensorboard

$ cd ~/tf_files
$ tensorboard --logdir training_summaries &


Note:

This command will fail with the following error if you already have a tensorboard running:

ERROR:tensorflow:TensorBoard attempted to bind to port 6006, but it was already in use

You can kill all existing TensorBoard instances with:
$ pkill -f "tensorboard"


启动训练脚本

$ cd ~/tf_files
$ python retrain.py \
--bottleneck_dir=bottlenecks \
--how_many_training_steps=500 \
--model_dir=inception \
--summaries_dir=training_summaries/basic \
--output_graph=retrained_graph.pb \
--output_labels=retrained_labels.txt \
--image_dir=flower_photos


如果不添加
--how_many_training_steps=500
,默认值为4000。

启动浏览器查看tensorboard

等待
~/tf_files/bottlenecks
中的bottlenecks文件生成结束后,可以启动浏览器,在地址栏中输入
localhost:6006
并回车,来查看训练进度。

小结

The retraining script will write out a version of the Inception v3 network with a final layer retrained to your categories to
tf_files/retrained_graph.pb
and a text file containing the labels to
tf_files/retrained_labels.txt
.

该图像识别模型,训练后的图像识别准确率应该在85%到99%。

测试重新训练的模型

$ cd ~/tf_files
$ curl -L https://goo.gl/3lTKZs > label_image.py
$ python label_image.py flower_photos/roses/2414954629_3708a1a04d.jpg


你应该看到类似以下的结果:

daisy (score = 0.99071)
sunflowers (score = 0.00595)
dandelion (score = 0.00252)
roses (score = 0.00049)
tulips (score = 0.00032)


参考

TensorFlow For Poets
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息