python caffe 在师兄的代码上修改成自己风格的代码
2016-01-29 17:40
513 查看
首先,感谢师兄的帮助。师兄的代码封装成类,流畅精美,容易调试;我的代码则是堆砌起来的,被师兄嘲笑说是在写脚本。好吧!我的代码只有我懂,哈哈!希望以后代码能写得工整一点,但现在还是先让自己弄懂。这里我做了一个简单的任务:对数字 0、1、2 三类图片进行分类,准确率为 0.9806666666666667。
(部分)代码分为:
1 train_net.py
2 test_net.py
3 test_data.py
4 pre_data.py
5 utils.py
6 layer/data_layer.py
7 layer/__init__.py
还有一些 Caffe 的常规文件没有放进来(例如 models/ 目录下的 solver.prototxt、deploy.prototxt 和预训练模型 bvlc_reference_caffenet.caffemodel)。
代码和数据:
(部分)代码分为:
1 train_net.py
# train_net.py -- fine-tune the pretrained reference CaffeNet on the digit
# images, fed by the Python DataLayer (layer/data_layer.py).
import time
import os
import numpy as np
import sys
import cv2
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe

# GPU mode is enabled; comment out the next line to train on the CPU.
caffe.set_mode_gpu()

# Root of the training set. DataLayer passes it to prepare_data, which samples
# images from its class sub-directories 0/, 1/ and 2/.
data_config = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'

# The solver reads its settings (net definition, learning rate, ...) from this file.
solver = caffe.SGDSolver('models/solver.prototxt')

# Initialise the weights from the pretrained reference CaffeNet model.
print('load pretrain model')
solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel')

# layers[0] is assumed to be the Python DataLayer (the layer class that
# provides SetDataConfig) -- TODO confirm against the train prototxt.
solver.net.layers[0].SetDataConfig(data_config)

SAVE_PATH = 'tmp.caffemodel'
for i in range(1, 10000):
    # Each round runs 5 SGD iterations.
    solver.step(5)
    # Periodic checkpoint every 100 rounds.
    if i % 100 == 0:
        solver.net.save(SAVE_PATH)

# The last in-loop checkpoint is written at i == 9900; save once more so the
# weights from the remaining 99 rounds are not thrown away.
solver.net.save(SAVE_PATH)

# TODO: evaluate on the test set here (currently done separately by test_net.py).
2 test_net.py
# test_net.py -- measure the classification accuracy of the trained model on
# the test images.
import time
import os
import random
import sys
sys.path.append("/home/wang/Downloads/caffe-master/python")
import caffe
import cv2
import numpy as np
from utils import PrepareImage
from test_data import test_data_pre

# Number of images pushed through the network in one forward pass.
test_num_once = 10

# Uncomment the line below to run on the GPU.
# caffe.set_mode_gpu()

# Load the deploy network definition and the weights saved by train_net.py.
net = caffe.Net('models/deploy.prototxt', caffe.TEST)
print('load pretrain model')
net.copy_from('tmp.caffemodel')

data_pre = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
(imgPaths, gt_label) = test_data_pre(data_pre)
num_img = len(imgPaths)


def load_image(img_path):
    """Read one image and preprocess it the same way as the training data.

    Returns a float32 array of shape (3, 227, 227).
    Raises IOError when OpenCV cannot decode the file (cv2.imread returns None).
    """
    img = cv2.imread(img_path)
    if img is None:
        raise IOError('cannot read image: %s' % img_path)
    # Same channel conversion as pre_data.prepare_data, so test input matches training.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return PrepareImage(img, (227, 227))


correct_num = 0
for start in range(0, num_img, test_num_once):
    batch_paths = imgPaths[start:start + test_num_once]
    n = len(batch_paths)
    batch = np.asarray([load_image(p) for p in batch_paths], dtype=np.float32)
    # Size the input blob to the real batch (the last one may be shorter),
    # instead of broadcasting one image into all test_num_once slots.
    net.blobs['data'].reshape(n, 3, 227, 227)
    net.blobs['data'].data[...] = batch
    net.forward()
    # One row of class scores per image; take the best class of each row.
    scores = net.blobs['cls_prob'].data.reshape(n, -1)
    predictions = scores.argmax(axis=1)
    correct_num += int((predictions == gt_label[start:start + n]).sum())
    # Progress message whenever the batch passes an index that is a multiple of 100.
    if any(idx % 100 == 0 for idx in range(start, start + n)):
        print("Please wait some minutes...")

if num_img == 0:
    print('No test images found under %s' % data_pre)
else:
    correct_rate = correct_num * 1.0 / num_img
    print('The correct rate is :', correct_rate)
3 test_data.py
import os
import numpy as np


def test_data_pre(path, num_classes=3):
    """Collect every test image path together with its class label.

    Images are read from ``<path>/<k>/`` for k = 0 .. num_classes-1; the
    sub-directory name is the label.

    Args:
        path: root directory of the test set; a trailing '/' is optional.
        num_classes: number of class sub-directories (default 3: digits 0, 1, 2).

    Returns:
        (img_list, label): list of image file paths, and a float32 numpy array
        holding the matching label for each path, in the same order.
    """
    img_list = []
    labels = []
    for idf in range(num_classes):
        class_dir = os.path.join(path, str(idf))
        # os.listdir order is arbitrary; sort it for a reproducible ordering.
        for name in sorted(os.listdir(class_dir)):
            img_list.append(os.path.join(class_dir, name))
            labels.append(idf)
    label = np.array(labels, dtype=np.float32)
    return (img_list, label)


# The name starts with "test_", so pytest would otherwise try to collect this
# helper as a test whenever a test module imports it.
test_data_pre.__test__ = False
4 pre_data.py
import os
import numpy as np
from random import randint
import cv2
from utils import PrepareImage, CatImage


def prepare_data(path, batchsize, num_classes=3):
    """Build one random training batch from the class sub-directories of *path*.

    For each sample a class k is drawn uniformly from 0 .. num_classes-1, then a
    random file from ``<path>/<k>/``. The file is loaded, flipped left-right
    with probability 1/2, resized to 227x227 and converted to a C x H x W array.

    Args:
        path: root directory holding sub-directories named 0, 1, ...
        batchsize: number of images in the batch.
        num_classes: number of class sub-directories (default 3).

    Returns:
        (imgData, label): the images stacked by CatImage into an
        (N, 3, 227, 227) float32 blob, and a float32 array of N labels.

    Raises:
        ValueError: if a sampled class directory contains no files.
        IOError: if OpenCV cannot read a sampled file.
    """
    # List every class directory once per batch rather than once per sample.
    class_files = [os.listdir(os.path.join(path, str(idf)))
                   for idf in range(num_classes)]
    img_list = []
    label = np.zeros(batchsize, dtype=np.float32)
    for i in range(batchsize):
        # Randomly pick a class, then a file from that class.
        idf = randint(0, num_classes - 1)
        files = class_files[idf]
        if not files:
            raise ValueError('no images in class directory %s'
                             % os.path.join(path, str(idf)))
        img_path = os.path.join(path, str(idf), files[randint(0, len(files) - 1)])
        img = cv2.imread(img_path)
        if img is None:
            raise IOError('cannot read image: %s' % img_path)
        # cv2.imread returns BGR, so this swaps the channels to RGB order;
        # test_net.py applies the same conversion, keeping the two consistent.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # Random horizontal flip as data augmentation.
        if randint(0, 1) > 0:
            img = img[:, ::-1, :]
        img_list.append(PrepareImage(img, (227, 227)))
        label[i] = idf
    imgData = CatImage(img_list)
    return (imgData, label)
5 utils.py
import os
import numpy as np


def PrepareImage(im, size):
    """Resize an H x W x C image and return it as a float32 C x H x W array.

    Args:
        im: image array with the channel axis last.
        size: target size passed to cv2.resize as (width, height).
    """
    # Imported here so that CatImage (pure numpy) can be used without OpenCV.
    import cv2
    im = cv2.resize(im, (size[0], size[1]))
    im = im.transpose(2, 0, 1)
    return im.astype(np.float32, copy=False)


# Per-channel fill values used for the padding in CatImage (channels 0, 1, 2).
# The original code labels these "mean value"; they look like BGR pixel
# means -- NOTE(review): confirm the intended channel order.
_PAD_MEAN = (102.9801, 115.9465, 122.7717)


def CatImage(im_list, mean=_PAD_MEAN):
    """Stack C x H x W images into one (N, len(mean), H_max, W_max) float32 blob.

    Each image is copied into the top-left corner of its slot. The area not
    covered by a smaller image keeps the per-channel fill value from *mean*.

    Args:
        im_list: non-empty sequence of 3-D arrays (channels first).
        mean: one fill value per output channel (default: the three
            values the original implementation hard-coded).

    Raises:
        ValueError: if im_list is empty.
    """
    if len(im_list) == 0:
        raise ValueError('CatImage needs at least one image')
    max_shape = np.array([im.shape for im in im_list]).max(axis=0)
    blob = np.empty((len(im_list), len(mean), max_shape[1], max_shape[2]),
                    dtype=np.float32)
    # Fill every channel with its padding value, broadcast over N, H and W.
    blob[...] = np.asarray(mean, dtype=np.float32).reshape(1, -1, 1, 1)
    for i, im in enumerate(im_list):
        blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
    return blob
6 layer/data_layer.py
import caffe
import numpy as np
from pre_data import prepare_data


class DataLayer(caffe.Layer):
    """Caffe Python layer that feeds randomly sampled training batches.

    Tops:
        top[0]: image blob (N, C, H, W) produced by prepare_data.
        top[1]: label vector (N,).

    SetDataConfig(path) must be called with the training image root before
    the first forward pass.
    """

    # Number of images drawn per forward pass.
    batch_size = 128

    def SetDataConfig(self, data_config):
        """Store the training image root directory passed to prepare_data."""
        self._data_config = data_config

    def GetDataConfig(self):
        """Return the directory set by SetDataConfig.

        Raises:
            RuntimeError: if SetDataConfig has not been called yet.
        """
        try:
            return self._data_config
        except AttributeError:
            raise RuntimeError(
                'DataLayer: call SetDataConfig(path) before running the net')

    def setup(self, bottom, top):
        # Initial shapes only; forward() reshapes both tops to the real batch.
        top[0].reshape(1, 3, 227, 227)
        top[1].reshape(1, 1)

    def reshape(self, bottom, top):
        # Shapes are set in forward(), once the batch has been loaded.
        pass

    def forward(self, bottom, top):
        path = self.GetDataConfig()
        (imgs, label) = prepare_data(path, self.batch_size)
        # The blob layout is (N, C, H, W).
        (N, C, H, W) = imgs.shape
        top[0].reshape(N, C, H, W)
        top[0].data[...] = imgs
        top[1].reshape(N)
        top[1].data[...] = label

    def backward(self, top, propagate_down, bottom):
        # Data layer: nothing to back-propagate.
        pass
7 layer/__init__.py
import data_layer
还有一些 Caffe 的常规文件没有放进来(例如 models/ 目录下的 solver.prototxt、deploy.prototxt 和预训练模型 bvlc_reference_caffenet.caffemodel)。
代码和数据:
相关文章推荐
- jquery ajax 提交 FormData
- 去掉NSString中的HTML标签
- 柱状图 js工具highcharts
- jQuery UI 实例 - 对话框(Dialog)
- EL表达式<c:out>标签属性escapeXml属性
- 使用Firefox轻松调试JS
- js实现滑动解锁功能(PC+Moblie)
- JavaScript中的原型和对象机制
- window.addEventListener来解决让一个js事件执行多个函数
- javascript:;与javascript:void(0);
- jQuery的Ajax详解
- 带你学习JQuery:事件冒泡和阻止默认行为
- 【 D3.js 入门系列 --- 8 】 对话操作(事件)
- jsp自定义标签-----EL表达式中连接两个字符串
- jsp自定义标签-----EL表达式中连接两个字符串
- jquery.datatable.js与CI整合 异步加载(大数据量处理)
- js面向对象编程以及继承
- bootstrap iCheck插件 全选和获取value值的解决方法
- jquery-mousewheel 插件
- 最小的k个数