您的位置:首页 > 其它

生成非图像类型的LMDB数据

2018-03-15 18:14 211 查看
最近在训练网络中会用到非图像类型的数据,我这里是将这种数据转换成LMDB类型作为一个数据层,加载进网络。主要用到caffe的Python接口。

1、在网络的中间层中,其接受一个1x6维的bottom数据作为输入;

2、每个训练样本对应的1x6维的数据存储到data.txt,同时记录其类别标签;

3、写入LMDB 。

#-*- coding: UTF-8 -*-
import numpy as np
import caffe
import lmdb
from caffe.proto import caffe_pb2
import sys,os

# 读入数据和对应的类别标签
theta_file=open('./data.txt','r')
label=open('./label.txt','r')
theta_list=[]
theta_label=[]
for line in theta_file:
content=line.strip().split(',')
theta=[]
for i in range(len(content)):
theta.append(float(content[i]))
theta_list.append(theta)
del content,theta
theta_file.close()

for line in label:
content=line.strip().split('\n')
theta_label.append(int(content[0]))

# 写入lmdb,需要将list转换为array
db = lmdb.open('data_lmdb', map_size=int(1e12))
with db.begin(write=True) as in_txn:
for i in range(len(theta_list)):
datum = caffe.proto.caffe_pb2.Datum()
datum.channels = 1
datum.height = 1
datum.width = 6
tmp_=theta_list[i]
tmp=np.array(range(6), dtype=np.float)
for j in range(6):
tmp[j]=tmp_[j]
label=int(theta_label[i])
datum.data = tmp.tobytes()
# datum.data = tmp.tostring()
datum.label=label
in_txn.put('{:0>10d}'.format(i), datum.SerializeToString())
db.close()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: