Mxnet学习笔记(2)--自定义DataIter
2017-01-15 15:32
429 查看
前言
之前在GPU集群上配置的caffe因为一系列人为因素崩溃,搭建的Tensorflow由于cuda版本太低有些实验不能跑,而恰逢管理员不在,只好找一款不受这些因素影响的框架,之前阅读过mxnet源码,源码不多,很容易懂。于是配置
了一把,成功只好用了。而我的实验基于多源数据,即包含两种输入,这里是face images和audio数据,如果单纯使用
官方提供DataIter不能够完成任务,只好自己写(由于数据较大,直接使用NDArrayIter不现实,不如直接自己重新设计
一种DataIter),当然本文章着重讲解如何自定义DataIter,细节还需参看源码。
话不多说,上干货。
目录
Mxnet中的DataIterDataIter
NDArrayIter
MXDataIter
自定义DataIter
Mxnet中的DataIter
DataIter类对象都在模块io.py中,而所有的DataIter都继承于基类
DataIter,其中
DataIter源码如下:
class DataIter(object): def __init__(self): self.batch_size = 0 def __iter__(self): return self def reset(self): pass def next(self): if self.iter_next(): return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), \ pad=self.getpad(), index=self.getindex()) else: raise StopIteration def __next__(self): return self.next() def iter_next(self): pass def getdata(self): pass def getlabel(self): pass def getindex(self): return None def getpad(self): pass
由以上代码可以看出,DataIter是一个迭代器,核心部分在方法
next(),而其中涉及方法
getdata(), getlabel(), getpad(), getindex(),但这是在不
重写方法
next()前提下,我们需要提供这四个方法;而如果重写
next(),我们仅需要为
DataBatch提供batch大小的data,label即可,置于pad等方法我们可以忽略或者借鉴
其他DataIter.接下来我们就这几个方法看看官方提供DataIter干了些什么。
NDArrayIter
NDarrayIter由名字可以看出它是基于ndarray的数据迭代器,即数据来源是numpy数据。由官方文档
mxnet.io.NDArrayIter可知,
NDarrayIter参数主要如下:
Parameters:
data (NDArray or numpy.ndarray, a list of them, or a dict of string to them.) – NDArrayIter supports single or multiple data and label.
label (NDArray or numpy.ndarray, a list of them, or a dict of them.) – Same as data, but is not fed to the model during testing.
batch_size (batch_size)
shuffle (boolean)
last_batch_handle (‘pad’, ‘discard’, ‘roll_over’)
上述五个参数,最主要的是前三个,对于data我们可以提供ndarray或者NDArray,结构可以使列表,字典形式。这里为什么可以ist, dict.我们来看源码,在类NDArrayIter中数据初始化时调用了下面方法方法
def _init_data(data, allow_empty, default_name): assert (data is not None) or allow_empty if data is None: data = [] if isinstance(data, (np.ndarray, NDArray)): data = [data] if isinstance(data, list): if not allow_empty: assert(len(data) > 0) if len(data) == 1: data = OrderedDict([(default_name, data[0])]) else: data = OrderedDict([('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)]) if not isinstance(data, dict): raise TypeError("Input must be NDArray, numpy.ndarray, " + \ "a list of them or dict with them as values") for k, v in data.items(): if not isinstance(v, NDArray): try: data[k] = array(v) except: raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \ "should be NDArray or numpy.ndarray") return list(data.items()) class NDArrayIter(DataIter): def __init__(self, data, label=None, batch_size=1, shuffle=False, last_batch_handle='pad'): super(NDArrayIter, self).__init__() self.data = _init_data(data, allow_empty=False, default_name='data') self.label = _init_data(label, allow_empty=True, default_name='softmax_label') # shuffle data if shuffle: idx = np.arange(self.data[0][1].shape[0]) np.random.shuffle(idx) self.data = [(k, array(v.asnumpy()[idx], v.context)) for k, v in self.data] self.label = [(k, array(v.asnumpy()[idx], v.context)) for k, v in self.label]
即如果我们给出data和label均为list,那么在
_init_data中处理最后会得到一个有序字典(OrderedDict),并赋予默认的name,数据为data,标签名字为softmax_label.这里注意,提供的名字必须与搭建网络时变量中的名字对应,不然会报错。
当然,如果这里data,label为字典类型时,那么在执行方法
_init_data时会使用字典中key代替defaultname。这里还是要注意name,name,name(重要的事情说三遍)。
而我们在源码中可以看到,在后面装饰器装饰的方法
provide_data,
provide_label,分别为向搭建好的sym中供应name和此次迭代中提供的数据shape.即mxnet的数据供应过程是通过名字和shape来实现的。
@property def provide_data(self): """The name and shape of data provided by this iterator""" return [(k, tuple([self.batch_size] + list(v.shape[1:]))) for k, v in self.data] @property def provide_label(self): """The name and shape of label provided by this iterator""" return [(k, tuple([self.batch_size] + list(v.shape[1:]))) for k, v in self.label]
以上便是NDArrayIter的核心部分,其余跟进阅读源码即可。
MXDataIter
MXDataIter是mxnet中的标准数据迭代器,在官方例子中没看到使用方法,同样继承与DataIter,并重写getXXX方法,方便在next()方法中
mx.io.DataBatch()提供参数。这部分需要进一步补充,以后遇到使用的例子进行补充。
自定义DataIter
这里自定义DataIter,其实和上述迭代器方式相同,继承DataIter即可,因为DataIter中构成迭代的部分主要在方法next()中,而其中集中于
DataBatch,这里我们看一下
DataBatch的结构:
class DataBatch(object): def __init__(self, data, label, pad=None, index=None, bucket_key=None, provide_data=None, provide_label=None): self.data = data self.label = label self.pad = pad self.index = index # the following properties are only used when bucketing is used self.bucket_key = bucket_key self.provide_data = provide_data self.provide_label = provide_label
可以看到,只要赋值给
DataBatch中参数data,label就可以达到数据迭代的效果,因为每个DataBatch是mxnet中默认的mini-batch数据对象。所以这里有两条路线可以走:
自定义自己的数据生成器,可以不断地生成batch_data, batch_label供给DataBatch
重写方法getXXX(),即
getdata,
getlabel,使用父类方法
DataIter.next()默认配置即可。
其实两种方法原理皆是为了想DataBatch提供数据,事不过改写不同的地方。不过在coding时需要注意以下几点:
注意
provide_data|label()方法提供的名字需要和sym构建过程中有关输入变量的name一致,并同时提供batch大小的数据的shape
如果data中包含多个输入,比如包含image, wordvec,在
pprovide_data|label中需要分开写:
return [('data1', shape1), ('data2', shape2)],同理标签如果有多个输出,也应该如此设计。
以上就是最近自定义是遇到的问题解决,后续会不断补充,有一句说一句,mxnet还是很优秀的,代码读起来很舒服。
相关文章推荐
- data-*自定义属性
- JS自定义data-*属性
- 使用HTML5中的element.dataset操作自定义data-*数据
- 【转载】HTML5自定义data属性
- yii CListView中使用CArrayDataProvider自定义数组作为数据
- HTML 5 的自定义 data-* 属性和jquery的data()方法的使用
- html5新特性data_*自定义属性使用
- Android自定义DataTimePicker(日期选择器)
- android 读取 AndroidManifest.xml 中的数据:版本号、应用名称、自定义K-V数据(meta-data)
- 浅谈HTML5的新特性——data-*自定义属性
- nRF51822 自定义UUID,ble_advdata_set的时候 NRF_ERROR_DATA_SIZE 错误的解决
- 背水一战 Windows 10 (20) - 绑定: DataContextChanged, UpdateSourceTrigger, 对绑定的数据做自定义转换
- HTML5的自定义属性data-* 的用法解析
- Html 5中自定义data-*特性
- core data UIColor转换为 自定义数据类型 (其他类型数据 转换同理)
- spring data mongodb学习以及为repository提供可扩展的自定义方法
- ADO.NET Data Service中如何自定义Operation
- GridView的HyperLinkField的DataNavigateUrlFormatString如何使用自定义的变量,而不是数据库绑定的值
- HTML自定义按钮上传图片并实现预览(同时解决getAsDataURL()弃用问题)
- HTML5 data-* 自定义属性