【开发日记】马桶识别之马桶分类,通过迁移学习进行马桶分类
2018-01-10 21:42
351 查看
通过上篇文章马桶识别之数据清洗,通过Resnet50清洗脏数据,得到了干净的马桶图片。但是由于评论区图片有限,每一类的图片只有300张左右。如果从头开始训练,数据量有点小。这时可以通过迁移学习,利用从ImageNet数据集中学习到的模型来进行迁移学习。
具体的做法是冻结网络(比如ResNet,GoogleNet,Vgg16等)前几层的权重和偏置,只留下最后一层的参数进行训练。其实,也是把前面几层网络当成特征提取器进行特征提取,然后再利用最后一层进行线性分类。
在Tensorflow的官方教程中,有一篇是关于迁移学习的教程,有兴趣的可以看一下:How to Retrain Inception's Final Layer for New Categories。本文正是基于该教程进行迁移学习。
1. 数据准备
数据组织结构如下图所示,创建一个总文件夹,下面放“种类”子文件夹,子文件夹下放图片
![](https://oscdn.geek-share.com/Uploads/Images/Content/201801/10/6c27e106c84277c8c9499ee61d6fe7bf)
2. 代码准备
确保~/tensorflow/examples 下有文件夹 image_retraining。如果没有的话,可以上Tensorflow的GitHub网址上下载(https://github.com/tensorflow/tensorflow)。
3. 开始训练
训练很简单,只要输入以下代码即可,默认使用预训练的 Inception V3模型。(注意更改为自己代码所在位置以及图片路径)python E:\Python35\Lib\site-packages\tensorflow\examples\image_retraining\retrain.py --image_dir F:\AI\proj\data4. 训练结果
使用我的笔记本电脑,差不多10分钟左右就可以出来结果。
训练的精度为68.4%,并不是很好,可能样本还是太少了,但是训练过程很简单,避免自己从头开始写一个神经网络。
下一步可以使用其他模型或者使用数据增强,以及调整一下其他参数来看一下精度能否继续提高。
具体的做法是冻结网络(比如ResNet,GoogleNet,Vgg16等)前几层的权重和偏置,只留下最后一层的参数进行训练。其实,也是把前面几层网络当成特征提取器进行特征提取,然后再利用最后一层进行线性分类。
在Tensorflow的官方教程中,有一篇是关于迁移学习的教程,有兴趣的可以看一下:How to Retrain Inception's Final Layer for New Categories。本文正是基于该教程进行迁移学习。
1. 数据准备
数据组织结构如下图所示,创建一个总文件夹,下面放“种类”子文件夹,子文件夹下放图片
2. 代码准备
确保~/tensorflow/examples 下有文件夹 image_retraining。如果没有的话,可以上Tensorflow的GitHub网址上下载(https://github.com/tensorflow/tensorflow)。
3. 开始训练
训练很简单,只要输入以下代码即可,默认使用预训练的 Inception V3模型。(注意更改为自己代码所在位置以及图片路径)python E:\Python35\Lib\site-packages\tensorflow\examples\image_retraining\retrain.py --image_dir F:\AI\proj\data4. 训练结果
使用我的笔记本电脑,差不多10分钟左右就可以出来结果。
训练的精度为68.4%,并不是很好,可能样本还是太少了,但是训练过程很简单,避免自己从头开始写一个神经网络。
下一步可以使用其他模型或者使用数据增强,以及调整一下其他参数来看一下精度能否继续提高。
相关文章推荐
- 【开发日记】马桶识别之马桶分类,增加图片数量再进行分类
- 【开发日记】马桶识别之马桶分类,利用百度人工智能定制化图像识别进行分类
- 【开发日记】马桶识别之数据清洗,通过Resnet50清洗脏数据
- 【开发日记】马桶识别之数据收集,通过Python抓取京东评论图片
- Qt简介以及如何配置Qt使用VS2010进行开发 分类: QT学习实践 2015-05-05 16:02 34人阅读 评论(0) 收藏
- keras迁移学习 使用vgg16进行手写数字识别
- 【开发日记】门没关好,通过树莓派+机器学习识别门关好没有
- CNTK API文档翻译(24)——使用深度迁移学习进行图像识别
- play framework如何进行模块化开发--学习笔记(借鉴同事、博客等资料自己试验通过!)
- TensorFlow迁移学习-使用谷歌训练好的Inception-v3网络进行分类
- Android学习(十六) 通过GestureDetector进行手势识别
- OpenCV3与深度学习实例-使用GoogLeNet模型进行图片分类识别
- 【J2EE核心开发学习笔记001】通过JDBC进行简单的增删改查(以MySQL为例)
- 【神经网络与深度学习】【Qt开发】【VS开发】从caffe-windows-visual studio2013到Qt5.7使用caffemodel进行分类的移植过程
- 【神经网络与深度学习】【Qt开发】【VS开发】从caffe-windows-visual studio2013到Qt5.7使用caffemodel进行分类的移植过程<二>
- 通过Ajax方式上传文件,使用FormData进行Ajax请求 博客分类: RESTful Web ServicesWeb前端开发
- play framework如何进行模块化开发--学习笔记(借鉴同事、博客等资料自己试验通过!)
- Android学习(十六) 通过GestureOverlayView进行手势识别
- 【开发日记】马桶型号识别
- 利用SpiderMonkey进行嵌入式开发——学习总结