您的位置:首页 > 编程语言 > Python开发

caffe的python layer

2017-07-27 15:57 225 查看

在prototxt中python layer的定义:

layer {
name: "data"
type: "Python" #type必须是python, caffe编译的时候config里面 with python layer参数要打开
top: "data"
top: "label"
include {
phase: TRAIN
}
python_param {
module: "caffe_python_layer" #module就是我们python层文件的名字
layer: "RoadSegLayer"        #PyLayer就写python文件里面的class的名字
#param_str是一个字符串,里面是字典的格式
param_str: "{  'data_dir' : '../image_dir/', 'mean' : 0, 'split' : 'list/train.txt', 'randomize' : True, 'batch_size' : 32  }"
}
}


python layer

这里我实现了一层读取npy的数据层:
#########################python layer########################
import caffe
import numpy as np
from PIL import Image
import random

#只需要完成 setup(), reshape(), forward(), backword() 这几个函数就行,由于这个是数据层,所以不需要backward。
class PyLayer(caffe.Layer):
"""
Load feature vector
"""
def setup(self, bottom, top):
"""
:param bottom:
:param top:
"""
param = eval(self.param_str) #self.param_str为我们在网络配置文件prototxt中定义的这一层的参数
#解析参数
self.data_dir = param['data_dir']
self.mean = np.array(param['mean'])
self.split = param['split']
self.seed = param.get('seed', None)
self.random = param.get('randomize', True)
self.batch_size = param.get('batch_size', 1)

#数据层一般是2个输出,即data和label
if len(top) != 2:
raise Exception("Need to define two tops: data and label.")
#数据层没有bottom
if len(bottom) != 0:
raise Exception("Do not define a bottom.")

#self.split是我们data层那个图片的list
#self.data_dir是存放图片的文件夹的路径
#split_f = "{}/{}".format(self.data_dir, self.split)
split_f = self.split
self.indices = open(split_f, 'r').read().splitlines()

self.idx = []
if 'train' not in self.split: #非训练模式下不会随机打乱list
self.random = False
if self.random:
random.seed(self.seed)
self.get_idx()
else:
self.idx = range(0, self.batch_size)

def reshape(self, bottom, top):

#这里就用来读入特征了
self.data = self.load_vector(self.idx)
self.label = self.load_labels(self.idx)

top[0].reshape(*self.data.shape)
top[1].reshape(*self.label.shape)

#前向
def forward(self, bottom, top):
top[0].data[...] = self.data
top[1].data[...] = self.label
if self.random:
self.get_idx()
else:
self.idx = [x+self.batch_size for x in self.idx]#need mod
if self.idx[0] >= len(self.indices):
for index in range(len(self.idx)):
self.idx[index] = self.idx[index] - len(self.indices)

#self.idx = range(0, self.batch_size)
#数据层没有后向
def backward(self, bottom, top):
pass

def load_vector(self, idx):

img = []
for ind in idx:
file_name = "{}/{}".format(self.data_dir, self.indices[ind%len(self.indices)].split()[0])
#print file_name
#这里我读取的是npy的文件,读取txt之类的都可以,reshape好就行
output = np.load(file_name)
output = output.reshape((np.shape(output)[1] , np.shape(output)[2], np.shape(output)[3]))
img.append(output)
return np.array(img)

def load_labels(self, idx):
label = []
for ind in idx:
label_num = int(self.indices[ind].split()[1])
label.append(label_num)
return np.array(label, dtype=np.int32)

def get_idx(self):
idx = []
for i in range(0, self.batch_size):
idx.append(random.randint(0, len(self.indices)-1))
self.idx = idx
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: