您的位置:首页 > Web前端

Caffe 框架：训练模型并测试数据

2015-10-26 10:55 · 302 次查看
1. 训练模型

#!/usr/bin/env sh
# Train the focal-length model with Caffe's command-line tool.
# Both paths below are relative, so run this from the Caffe root directory.
SOLVER=examples/focal_length/focal_solver.prototxt
./build/tools/caffe train --solver="$SOLVER"


2. 测试数据

import caffe
from caffe.proto import caffe_pb2
import numpy as np
import cv2

# Inference configuration: compute device, network definition and trained
# weights, and the directory holding the test images plus the list file
# that names them (one image path per line, relative to data_path).
run_mode = 'gpu'  # must be 'gpu' or 'cpu'
deploy_file = 'focal_deploy.prototxt'
weight_file = 'focal_iter_5000.caffemodel'
data_path = '/home/lei/project/caffe/caffe-master/data/focal_length/test/'
list_file = 'list.txt'

# Select the compute device, then load the trained network in TEST phase.
if run_mode == 'gpu':
    caffe.set_device(0)  # GPU id 0
    caffe.set_mode_gpu()
elif run_mode == 'cpu':
    caffe.set_mode_cpu()
else:
    # Fail loudly rather than silently keeping whatever mode Caffe is in
    # (e.g. a typo like 'GPU' would otherwise go unnoticed).
    raise ValueError(
        "run_mode must be 'gpu' or 'cpu', got {!r}".format(run_mode))
net = caffe.Net(deploy_file, weight_file, caffe.TEST)

# Mean image used for mean subtraction in the test loop.
# mean_file is a serialized Caffe BlobProto (the *.binaryproto format).
mean_file = '/home/lei/project/caffe/caffe-master/data/ilsvrc12/imagenet_mean.binaryproto'
mean_blobproto = caffe_pb2.BlobProto()
# Context manager so the file handle is closed as soon as it has been read.
with open(mean_file, 'rb') as fid:
    mean_blobproto.ParseFromString(fid.read())
mean_array = caffe.io.blobproto_to_array(mean_blobproto)
mean_array = mean_array.astype(np.float32)
# NOTE(review): the loaded mean is replaced by zeros of the same shape and
# dtype, so the later `data_in - mean_array` is a no-op and only the 1/256
# scaling is applied. Presumably this matches how the model was trained --
# confirm against the training prototxt before removing this line.
mean_array = np.zeros(mean_array.shape, mean_array.dtype)

# Classify every image named in list_file (one path per line, relative to
# data_path) and print the top-1 class index and its probability.
#
# The network's input blob is named 'data' (it is fed via
# forward_all(data=...)). Its (N, C, H, W) shape gives the spatial size each
# image must have. NOTE(review): mean_array must have the same H, W for the
# subtraction below to broadcast -- confirm if the deploy input size changes.
_, _, input_height, input_width = net.blobs['data'].data.shape

with open(data_path + list_file) as fid:
    for line in fid:
        # Drop the trailing newline (it used to leak into the printed output)
        # and skip blank lines such as a trailing empty line.
        key = line.strip()
        if not key:
            continue
        filename = data_path + key

        img = cv2.imread(filename)
        # cv2.imread returns None instead of raising when it cannot read a file.
        if img is None:
            print('{}: could not read image, skipped'.format(filename))
            continue

        # Show each image for 500 ms as a visual check.
        cv2.imshow('name', img)
        cv2.waitKey(500)

        # Bring the image to the network's input size. A plain reshape only
        # works (and only preserves pixel layout) when it already matches.
        if img.shape[:2] != (input_height, input_width):
            img = cv2.resize(img, (input_width, input_height))

        # OpenCV HWC (BGR) uint8 -> float32 NCHW batch of one image.
        img = img.astype(np.float32)
        data_in = img.transpose((2, 0, 1))[np.newaxis]
        data_in = data_in - mean_array
        data_in = data_in * 0.00390625  # 1/256: scale pixels to [0, 1)

        netout = net.forward_all(data=data_in)
        prob_batch = netout['prob']
        guess = prob_batch.argmax(axis=1)  # the type is np.int64
        prob = prob_batch.max(axis=1)

        print('{}: guess: {}, prob: {}'.format(key, guess[0], prob[0]))

# Close the preview window opened by cv2.imshow.
cv2.destroyAllWindows()

# /home/lei/project/caffe/caffe-master/examples/focal_length
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: