TensorFlow 对象检测 API 教程 4
2018-02-11 03:01
495 查看
TensorFlow 对象检测 API 教程 - 第4部分:训练模型
在本教程中,认为已经选择了预先训练的模型,找到了现有的数据集或创建了自己的数据集,并将其转换为TFRecord文件。现在准备好训练自己模型。
一. 模型配置文件
如果你以前有转移学习的经验,可能会产生一个自从本教程第二部分以来一直徘徊的问题。那个问题是,如何修改被设计为在COCO数据集的90个类上工作的预先训练的模型,以处理新数据集的
X个类?要在
object detection API之前完成,必须删除网络的最后 90 个神经元分类层,并将其替换为新的图层。下面显示了
TensorFlow中的一个示例。
# Assume fc_2nd_last is the 2nd_last fully connected layer in your network and nb_classes is the number of classes in your new dataset. shape = (fc_2nd_last.get_shape().as_list()[-1], nb_classes) fc_last_W = tf.Variable(tf.truncated_normal(shape, stddev=1e-2)) fc_last_b = tf.Variable(tf.zeros(nb_classes)) logits = tf.nn.xw_plus_b(fc_2nd_last, fc_last_W, fc_last_b)
要使用
object detection API来实现这一点,只需修改模型配置文件中的一行代码即可。在克隆
TensorFlow models1的位置,进入到
object_detection/samples/configs目录。在此文件夹中,可以找到所有预先训练的模型的配置文件。
复制所选模型的配置文件,并将其移动到一个新文件夹,并在其中执行所有训练。在这个新文件夹中,创建一个名为
data的文件夹并将
TFRecord文件移动到其中。创建另一个名为
models的文件夹,并将所选择的预训练模型的
.ckpt(检查点)文件(其中3个)移动到此文件夹中。回想一下,
model_detection_zoo.md包含每个预先训练的模型的下载链接,这里的每个模型的下载将不仅包含
.pb文件(在教程第1部分的
jupyter notebook中使用过),还包含
.ckpt文件。在
models文件夹内创建另一个名为
train的文件夹。
二. 修改配置文件
在文本编辑器中打开新移动的配置文件,在最上面将类的数量更改为数据集中的数量。接下来,将fine_tune_checkpoint的路径更改为指向
model.ckpt文件。如果遵循模型结构,建议改为:
fine_tune_checkpoint: "models/model.ckpt"
参数
num_steps决定在完成之前将要运行多少个训练步骤。这个数字实际上取决于数据集的大小以及其他因素(包括让模型训练的时间)。一旦开始训练,建议先看看每个训练步骤需要多长时间,并相应地调整
num_steps。
接下来,需要更改训练数据集和评估数据集的
input_path和
label_map_path。
Input_path只是到自己的
TFRecord文件。在可以设置
label_map_path的路径之前,需要创建它应该指向的文件。它所要查找的是一个
.pbtxt文件,其中包含数据集每个标签的
ID和
名称。可以按照以下格式在任何文本文件中创建此文件。
item { id: 1 name: 'Green' } item { id: 2 name: 'Red' }
确保从
id:1开始,而不是
0。 建议把这个文件放在自己的数据文件夹中。最后将
num_examples设置为拥有的评估样本的数量。
三. 训练
进入object_detection文件夹并将
train.py复制到新创建的培训文件夹中。要开始训练,只需将终端窗口导航到此文件夹(确保已按照教程第1部分中的安装说明操作),然后在命令行中输入
python train.py --logtostderr --train_dir=./models/train --pipeline_config_path=rfcn_resnet101_coco.config
pipline_config_path指向配置文件。现在开始培训。当心,根据你的系统,培训可能需要几分钟的时间才能开始,所以如果它没有崩溃或停止,请给它更多的时间。
如果计算机内存不足会导致训练的失败,可以尝试多种解决方案。首先尝试添加参数
batch_queue_capacity: 2 prefetch_queue_capacity: 2
到
train_config部分的配置文件。例如,将两行放在
gradient_clipping_by_norm和
fine_tune_checkpoint之间。上面的数字
2应该只是开始训练的开始值。这些值的默认值分别是
8和
10,增加这些值应该有助于加速训练。
就是这样,现在已经开始训练,这将能够调整模型!如果想更好地了解训练的进展情况,可以考虑使用TensorBoard 。
在接下来的文章将讲述说明如何保存所训练的模型,并在项目部署了!
相关文章推荐
- TensorFlow 对象检测 API 教程2
- TensorFlow 对象检测 API 教程3
- TensorFlow 对象检测 API 教程5
- 实践操作:六步教你如何用开源框架Tensorflow对象检测API构建一个玩具检测器
- tensorflow学习笔记九:将 TensorFlow 移植到 Android手机,实现物体识别、行人检测和图像风格迁移详细教程
- Tensorflow物体检测(Object Detection)API的使用
- Android开发 API人脸检测实例教程(内含源码)
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程
- 将 TensorFlow 移植到 Android手机,实现物体识别、行人检测和图像风格迁移详细教程
- 安卓手机如何玩转「动作手势检测」?有TensorFlow就够了 | 实用教程
- 百度AI攻城狮,用TensorFlow API训练目标检测模型(浣熊超可爱)
- 教程| 盯住梅西:TensorFlow目标检测实战---
- 将 TensorFlow 移植到 Android手机,实现物体识别、行人检测和图像风格迁移详细教程
- tensorflow目标检测API实现
- TensorFlow全新的数据读取方式:Dataset API入门教程
- TensorFlow全新的数据读取方式:Dataset API入门教程(转)
- 将 TensorFlow 移植到 Android手机,实现物体识别、行人检测和图像风格迁移详细教程
- tensorflow 目标检测API学习之protobuf
- 基于ubuntu16.04下anaconda中tensorflow环境的目标检测API安装