
Building a Caffe Clone from Scratch · Part 10: The IO System (III)

2016-03-24 16:06

Data Transformation

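In IO (II) we buffered the raw data into Datum objects and pushed those Datums into the producer buffer. That is still a long way from consumption, though.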

Before the data can be consumed (used), the most important step is data transformation.

ImageNet

The data ImageNet provides is quite raw: image sizes vary, and so does the share of each image taken up by the region of interest (ROI).

template<typename Dtype>
void DataTransformer<Dtype>::transform(const Datum& datum, Dtype* shadow_data){
    // pixels can be packed into a string,
    // because each pixel ranges over 0~255 (one char)
    const string& data = datum.data();
    const int datum_channels = datum.channels();
    const int datum_height = datum.height();
    const int datum_width = datum.width();
    const int crop_size = param.crop_size();
    const Dtype scale = param.scale();
    // mirror with probability 1/2 when enabled, assuming rand(n) returns an int in [0, n)
    // (the same helper used for the crop offsets below)
    const bool must_mirror = param.mirror() && rand(2);
    const bool has_mean_file = param.has_mean_file();
    const bool has_uint8 = data.size() > 0;  // pixels are stored as bytes in a string
    const bool has_mean_value = mean_vals.size() > 0;
    CHECK_GT(datum_channels, 0);
    CHECK_GE(datum_height, crop_size);
    CHECK_GE(datum_width, crop_size);
    Dtype *mean = NULL;
    if (has_mean_file){
        CHECK_EQ(datum_channels, mean_blob.channels());
        CHECK_EQ(datum_height, mean_blob.height());
        CHECK_EQ(datum_width, mean_blob.width());
        mean = mean_blob.mutable_cpu_data();
    }
    if (has_mean_value){
        CHECK(mean_vals.size() == 1 || mean_vals.size() == datum_channels)
            << "Channel's mean value must be provided as a single value or as many as channels.";
        // replicate a single mean value across all channels
        if (datum_channels > 1 && mean_vals.size() == 1)
            for (int i = 0; i < datum_channels - 1; i++)
                mean_vals.push_back(mean_vals[0]);
    }
    int h_off = 0, w_off = 0, height = datum_height, width = datum_width;
    if (crop_size){
        height = crop_size;
        width = crop_size;
        // train phase: random cropping
        if (phase == TRAIN){
            h_off = rand(datum_height - height + 1);
            w_off = rand(datum_width - width + 1);
        }
        // test phase: deterministic center cropping
        else{
            h_off = (datum_height - height) / 2;
            w_off = (datum_width - width) / 2;
        }
    }
    Dtype element;
    int top_idx, data_idx;
    // copy datum values into shadow_data (one item of the batch)
    for (int c = 0; c < datum_channels; c++){
        for (int h = 0; h < height; h++){
            for (int w = 0; w < width; w++){
                data_idx = (c*datum_height + h_off + h)*datum_width + w_off + w;
                // mirroring flips the width axis: column w lands at width - 1 - w
                if (must_mirror) top_idx = (c*height + h)*width + (width - 1 - w);
                else top_idx = (c*height + h)*width + w;
                if (has_uint8){
                    // char must not be cast to Dtype directly: where char is signed,
                    // bytes above 127 turn negative (seen with Cifar10)
                    element = static_cast<Dtype>(static_cast<uint8_t>(data[data_idx]));
                }
                else element = datum.float_data(data_idx);  // Dtype <- float
                if (has_mean_file) shadow_data[top_idx] = (element - mean[data_idx])*scale;
                else if (has_mean_value) shadow_data[top_idx] = (element - mean_vals[c])*scale;
                else shadow_data[top_idx] = element*scale;
            }
        }
    }
}


DataTransformer::transform()
The above is the core operation shared by the transform overloads, and it is fairly verbose.

First, read the input dimensions from the Datum and perform a random crop.

In the training phase, two random offsets into the original image, h_off and w_off, are drawn.

In the test phase, the 10-region multi-crop prediction of [Krizhevsky12] is not implemented by default; only a single center crop is provided.

Rewrite this part to fit your needs. GoogLeNet, for example, expands testing to 144 crop regions; see [Szegedy14] for details.

Next, iterate channel by channel and pixel by pixel (over the cropped width and height):

data_idx combines the position inside the crop with the crop offsets, and indexes a pixel of the original image.

top_idx indexes the corresponding position in the cropped image.

To mirror the image (flip the width axis), replace w with (width - 1 - w) in the last term of the top_idx computation.

uint8 needs special care here:

The characters in a string are of type char, while uint8 is unsigned char, so an explicit cast is required.

Datasets such as MNIST and Cifar10 store each pixel as a uint8.

In a signed 8-bit type the top bit is the sign bit: uint8 covers [0, 255], whereas int8 covers [-128, 127].

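Without the cast, on platforms where char is signed, the top bit of a value read from the char string is taken as the sign, so every pixel above 127 comes out negative, which clearly cannot represent our pixels.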

Finally, mean subtraction and scaling are done together in a single line.

template<typename Dtype>
void DataTransformer<Dtype>::transform(const Datum& datum, Blob<Dtype>* shadow_blob){
    const int num = shadow_blob->num();
    const int channels = shadow_blob->channels();
    const int height = shadow_blob->height();
    const int width = shadow_blob->width();
    CHECK_EQ(channels, datum.channels());
    CHECK_GE(num, 1);
    CHECK_LE(height, datum.height());  // allowing crop
    CHECK_LE(width, datum.width());
    Dtype *base_data = shadow_blob->mutable_cpu_data();
    transform(datum, base_data);
}


This transform overload is a thin wrapper that accepts a Blob (optional).

Full Code

compute_mean.cpp

https://github.com/neopenx/Dragon/blob/master/Dragon/compute_mean.cpp

io.hpp

https://github.com/neopenx/Dragon/blob/master/Dragon/include/io.hpp

data_transformer.hpp

https://github.com/neopenx/Dragon/blob/master/Dragon/data_include/data_transformer.hpp

data_transformer.cpp

https://github.com/neopenx/Dragon/blob/master/Dragon/data_src/data_transformer.cpp