Caffe小玩意(3)-利用py-faster-rcnn自定义输入数据
2016-07-01 12:20
513 查看
Caffe小玩意(3)-利用py-faster-rcnn自定义输入数据
众所周知,caffe是现有deep learning framework中最为自动化的,我们甚至可以只定义prototxt文件而不需要写代码,就完成整个网络的训练。正是由于它的高度自动化,当我们想要修改其中的模块,就不是一件容易的事了。caffe本身自带了一些标准通用的dataset,我们可以比较简单地使用它们。此外,对于一些其他的输入形式,caffe也给出了一些指示:
http://caffe.berkeleyvision.org/tutorial/data.html
http://caffe.berkeleyvision.org/tutorial/layers.html#data-layers
但是,对于那种label不是简单变量的输入,我们应该怎么输入到caffe里呢?(例如:显著性检测问题,我们的label应该是一幅灰度图像;人体关节检测问题,我们的label应该是一个tensor)。那么今天,我们就来看看如何利用rbg大神的py-faster-rcnn框架来自己定制输入数据:
https://github.com/rbgirshick/py-faster-rcnn
首先,pull这个repository到本地目录(之后的./就代表在本地下的这个目录),然后运行./data/目录下的script下载数据(这些数据本身不是必要的,只是因为我之前需要finetune模型将它们下载了下来,之后的路径、操作等等也基于这一事实)。
好了,现在会多了一些目录出来。我们需要将数据(假设图像输入就是.jpg文件,label是python的numpy array,即.npy文件)放到相应的地方,即./data/VOCdevkit2007/VOC2007/JPEGImages。但是呢,这个目录下是原来的VOC2007数据集的图像输入,所以我建议在这个目录下再新建一个目录(这里叫dlib)。因此实际存放路径是:
./data/VOCdevkit2007/VOC2007/JPEGImagesd/dlib
之后,我们需要为这些输入数据写xml文件。每一份输入图像都对应一个xml文件,内容如下(不需要注重格式):
<annotation><folder>VOC2007</folder><filename>image_0046.jpg</filename><source><database>dlib facial landmark</database><annotation>Yuliang Zou</annotation></source><size><width>400</width><height>300</height><depth>3</depth></size><segmented>0</segmented></annotation>
同样地,为了与原来数据的xml文件混淆,新建一个dlib文件夹,因此这些新xml文件的存放路径为:
./data/VOCdevkit2007/VOC2007/Annotations/dlib
以上的操作,都没有对dataset进行training set与test set的区分,下面我们就来完成这件事。打开目录:
./data/VOCdevkit2007/VOC2007/ImageSets/Main
我们可以看到很多的txt文件,先把原来的trainval.txt与test.txt备份好。然后,新建自己的trainval.txt与test.txt,每一行都是输入图像的名称,例:
dlib/100032540_1 dlib/1002681492_1 dlib/1004467229_1 ...
(直接用这两个txt文件的名字,是因为改用新的会有点麻烦,详见附录)
之后,我们需要修改相应的python代码使得数据可以顺利导入。
(1)在
lib/roi_data_layer/layer.py里的setup()函数,我们需要添加如下代码,为label分配空间:
top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 68, 38, 50) self._name_to_top_map['heatmap'] = idx idx += 1
我这里的label是facial landmark,一共有68个2-d array。然后把下面的一些不需要的部分删掉(不然之后可能会报错)。
(2)在
.lib/utils/blob.py里新定义一个函数:
def heatmap_list_to_blob(hms): """ Convert a list of heat maps into a network input.""" num_hms = len(hms) blob = np.zeros((num_hms, 68, 38, 50), dtype=np.float32) for i in xrange(num_hms): hm = hms[i] blob[i] = hm.transpose((2,0,1)) return blob
这个函数可以将包含若干label的python list转换为caffe的blob数据结构。
(3)在
lib/roi_data_layer/minibatch.py里导入刚刚定义的heatmap_list_to_blob函数,然后新定义函数:
def _get_heatmap_blob(roidb): """Get a batch of heat maps""" num_images = len(roidb) hms = [] for i in xrange(num_images): hm = np.load(roidb[i]['heatmap']) hms.append(hm) # Create a blob to hold the input heat maps blob = heatmap_list_to_blob(hms) return blob
然后,在get_minibatch()函数中加入如下几行代码:
# Get the imput heat map blob, formatted for caffe hm_blob = _get_heatmap_blob(roidb) blobs['heatmap'] = hm_blob
(4)在
./lib/roi_data_layer/roidb.py的prepare_roidb()函数中这行代码之后:
roidb[i]['image'] = imdb.image_path_at(i)
加入这么一行:
roidb[i]['heatmap'] = roidb[i]['image'][0:len(roidb[i]['image'])-3] + 'npy'
相信看到这里大家也知道了,
imdb.image_path_at(i)获取的是输入图像的完整路径,我们进行些许修改就可以得到label的完整路径。
最后,我们需要修改train.prototxt,这个按自己的需要定制就可以了,比较简单,就不详述了。
在最后之后,如果要测试性能,需要自己对
./lib/fast_rcnn/test.py进行修改。这里不再详述,我相信当你成功地开始训练的时候,已经对这些内容比较了解了,可以比较容易地写出自己需要的版本。
当然,完成了以上的所有步骤之后,可能还是会出现某些问题。
1.毕竟我的xml文件比原来的简化了不少,可以按实际情况删掉相应的code(原来的代码可能会导入xml文件的一些参数,但是我省略了那些参数)
2.原来代码对于输入图像的scaling比较奇怪,那边有可能会出错。对于某些输入尺寸固定的dataset,或许你可以修改
lib/roi_data_layer/layer.py里的setup()函数,其中会有一行
top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3, _, _)
最后的两个参数是height和width,按需要修改。
最近折腾这个东西也折腾了很久,甚是头疼,更是加深了我对rbg大神的仰慕之情。行文有些混乱,如果有不明白的欢迎留言,大家一起交流。
附录:
./lib/datasets/factory.py这份代码负责构造dataset
line 15 - 20:
# Set up voc_<year>_<split> using selective search "fast" mode for year in ['2007', '2012']: for split in ['train', 'val', 'trainval', 'test']: name = 'voc_{}_{}'.format(year, split) __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
./experiments/scripts/faster_rcnn_end2end.sh这份bash文件负责指定训练与测试时所用的dataset
line 27 - 28:
TRAIN_IMDB="voc_2007_trainval" TEST_IMDB="voc_2007_test"
相关文章推荐
- sadojciscjsd
- MANIFEST.MF错误
- 【USACO FEB 2010 SILVER】吃巧克力(Chocolate Eating)
- 自己实现一个javascript事件模块
- Bootstrap实现遮罩层
- css-position
- jquery弹幕效果
- CSS3的column-fill属性对齐列内容高度的用法详解
- N沟道还是P沟道MOSFET
- 关于bootstrap--表单(水平表单)
- 自己整理的一些前端资料
- jsp中文字符乱码问题
- Leetcode-populating-next-right-pointers-in-each-node
- HTML的学习(第2篇)
- Angular_ng-repeat中的问题
- Less框架中将CSS强制打包到单个文件中的技巧
- CSS :hover 伪类
- html运行原理
- Angular JS filter
- jQuery 2.0.3 源码分析core - 整体架构