您的位置:首页 > 其它

我的第一个svm程序:手写字识别

2015-05-03 10:42 176 查看
之前学过svm相关知识,基本原理不算复杂,今天做了一个手写字识别程序,总算验证了svm的效果。

因为只是验证效果,实现上原则是简单,使用python + libsvm + PIL(python image library)。这部分工作花了一些时间:

PIL:
http://www.pythonware.com/products/pil/
下载源码包,解压之后运行:python setup.py install即可。

mac下python libsvm安装使用:http://blog.csdn.net/u012774963/article/details/14640583

libsvm python接口介绍:http://blog.csdn.net/lqhbupt/article/details/8599295

说是手写字,其实只是一到十这十个汉字,这样比较简单,而且收集的样本不太多。这十个汉字,在mac上用paintbrush前前后后画了259个80*80的png图片。图片缩放为16*16,二值化之后用一个256维的向量表示,简单粗暴。准备训练数据文件:inittraindata.py
#! /usr/bin/env python

import Image
import os

GRID_SIZE = 16                         # every image is scaled to 16x16
FEATURE_COUNT = GRID_SIZE * GRID_SIZE  # one feature per pixel
CLASS_COUNT = 10                       # input dirs / output files are numbered 1..10


def feature_line(pixdata, label):
    """Return one libsvm-format line (without newline) for a pixel grid.

    pixdata -- object indexed as pixdata[x, y] whose items are sequences;
               only item [0] (the first channel) is inspected.
    label   -- label text placed at the start of the line ("1" or "-1").

    Features are numbered 1..256.  For zero-based position k the pixel at
    x = k // 16, y = k % 16 is read; a first-channel value of 255 is written
    as 0 and anything else as 1.  Every "index:value" pair is followed by a
    single space, exactly as the original script wrote it.
    """
    parts = [label + " "]
    for k in range(FEATURE_COUNT):
        bit = 0 if pixdata[k // GRID_SIZE, k % GRID_SIZE][0] == 255 else 1
        parts.append("%d:%d " % (k + 1, bit))
    return "".join(parts)


def main():
    """Build the one-vs-rest training files ocr_1 .. ocr_10.

    Each .png file found in directory "<i>" contributes one line to every
    output file: label 1 in ocr_<i> and label -1 in all the others.
    """
    outputs = []
    try:
        for i in range(1, CLASS_COUNT + 1):
            outputs.append(open('ocr_' + str(i), 'wb'))

        for i in range(1, CLASS_COUNT + 1):
            for item in os.listdir(str(i)):
                path = os.path.join(str(i), item)
                if not (os.path.isfile(path) and path.endswith(".png")):
                    continue
                img_org = Image.open(path)
                img = img_org.resize((GRID_SIZE, GRID_SIZE), Image.NEAREST)
                pixdata = img.load()
                # The feature part is the same for every output file; only
                # the label differs, so encode each variant once per image.
                # The files are opened in binary mode, hence the encode().
                positive = (feature_line(pixdata, "1") + "\n").encode("ascii")
                negative = (feature_line(pixdata, "-1") + "\n").encode("ascii")
                for j in range(1, CLASS_COUNT + 1):
                    outputs[j - 1].write(positive if j == i else negative)
    finally:
        # The original wrote `o.close` without parentheses, which never
        # closed anything; close every file that was actually opened.
        for out in outputs:
            out.close()


if __name__ == '__main__':
    main()


训练数据并保存模型save.py:

#! /usr/bin/env python

import sys
from svmutil import *
import Image
import random

# libsvm options used for every one-vs-rest model.
# (A disabled experiment in the original script trained models 3 and 4
# with '-c 10000' instead.)
TRAIN_OPTIONS = '-c 3 -g 0.015625'

# Read each training file ocr_<n>, train a model on it and save the
# result as model_<n>, for n = 1..10.
for digit in range(1, 11):
    labels, samples = svm_read_problem('./ocr_' + str(digit))
    model = svm_train(labels, samples, TRAIN_OPTIONS)
    svm_save_model('./model_' + str(digit), model)


预测predict.py:

#! /usr/bin/env python

import sys
from svmutil import *
import Image

GRID_SIZE = 16                         # the image is scaled to 16x16
FEATURE_COUNT = GRID_SIZE * GRID_SIZE  # one feature per pixel
MODEL_COUNT = 10                       # model files are numbered 1..10


def feature_line(pixdata, label):
    """Return one libsvm-format line (without newline) for a pixel grid.

    pixdata -- object indexed as pixdata[x, y] whose items are sequences;
               only item [0] (the first channel) is inspected.
    label   -- label text placed at the start of the line.

    Features are numbered 1..256.  For zero-based position k the pixel at
    x = k // 16, y = k % 16 is read; a first-channel value of 255 is written
    as 0 and anything else as 1.
    """
    parts = [label + " "]
    for k in range(FEATURE_COUNT):
        bit = 0 if pixdata[k // GRID_SIZE, k % GRID_SIZE][0] == 255 else 1
        parts.append("%d:%d " % (k + 1, bit))
    return "".join(parts)


def main():
    """Classify the image named on the command line with the ten models.

    Prints the first decision value of every model, then the number of the
    model whose value lies closest to 1.0.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: predict.py IMAGE")

    # load
    models = []
    for i in range(1, MODEL_COUNT + 1):
        models.append(svm_load_model('./model_' + str(i)))

    # predict
    path = sys.argv[1]
    img_org = Image.open(path)
    img = img_org.resize((GRID_SIZE, GRID_SIZE), Image.NEAREST)
    pixdata = img.load()

    # The sample is round-tripped through a file so that svm_read_problem
    # produces the data structures svm_predict expects.  The label -1 is a
    # placeholder required by the file format.
    with open("tmpfile", "wb") as tmpfile:
        tmpfile.write((feature_line(pixdata, "-1") + "\n").encode("ascii"))

    # The sample does not change between models, so read it back only once.
    y, x = svm_read_problem("tmpfile")

    best_distance = float("inf")
    best_idx = -1
    for i in range(1, MODEL_COUNT + 1):
        label, acc, val = svm_predict(y, x, models[i - 1])
        print(val[0][0])
        distance = abs(val[0][0] - 1.0)
        if distance < best_distance:
            best_distance = distance
            best_idx = i

    print("probably is:  %s" % best_idx)


if __name__ == '__main__':
    main()


使用c-svm,核函数使用RBF,参数c=3,gamma=1.0/64。参数是怎么选的?用的是简单粗暴的grid search,gridsearch.py:
#! /usr/bin/env python

from svmutil import *
import random

def test(y, x, c, g):
    """Estimate one-vs-rest accuracy for (c, g) by 10-fold cross-validation.

    y, x -- parallel sequences with one entry per class; y[k - 1] holds the
            labels and x[k - 1] the feature vectors of the class-k problem.
            A label equal to 1 marks a sample that belongs to class k.  All
            entries are indexed with the same sample positions.
    c, g -- values passed to svm_train as '-c <c> -g <g>'.

    For every fold one model per class is trained on the remaining samples.
    Each held-out sample is assigned to the class whose model yields a first
    decision value closest to 1.0.  Returns the fraction of correctly
    assigned samples, averaged over the folds.

    Raises ValueError when there are fewer samples than folds, because a
    fold would then be empty.
    """
    folds = 10
    num_classes = len(y)
    count = len(y[0])
    if count < folds:
        raise ValueError("need at least %d samples, got %d" % (folds, count))

    options = '-c ' + str(c) + ' -g ' + str(g)
    correct_rate = 0.0
    # n-fold cross-validation
    for i in range(folds):
        lo = count * i // folds
        hi = count * (i + 1) // folds
        marr = []   # one trained model per class
        tarr = []   # one (labels, features) held-out pair per class
        # True class of each held-out sample; 0 means "no class claimed it".
        # The original sized this list with count*(i+1)*10 by mistake.
        answers = [0] * (hi - lo)

        for k in range(1, num_classes + 1):
            labels = y[k - 1]
            samples = x[k - 1]
            # training sets: everything outside [lo, hi)
            yy = list(labels[:lo]) + list(labels[hi:])
            xx = list(samples[:lo]) + list(samples[hi:])
            marr.append(svm_train(yy, xx, options))
            # test sets: the slice [lo, hi)
            for j in range(lo, hi):
                if labels[j] == 1:
                    answers[j - lo] = k
            tarr.append((list(labels[lo:hi]), list(samples[lo:hi])))

        print(answers)
        # predicting
        correct_count = 0
        for j in range(hi - lo):
            best_distance = float("inf")
            best_class = -1
            for k in range(1, num_classes + 1):
                fold_y, fold_x = tarr[k - 1]
                label, acc, val = svm_predict(fold_y[j:j + 1],
                                              fold_x[j:j + 1],
                                              marr[k - 1])
                distance = abs(val[0][0] - 1.0)
                if distance < best_distance:
                    best_distance = distance
                    best_class = k
            print("probably is %s  answer is %s" % (best_class, answers[j]))
            if answers[j] == best_class:
                correct_count += 1
        correct_rate += float(correct_count) / (hi - lo)

    correct_rate /= folds
    print("c= %s g= %s avg_correct_rate= %s" % (c, g, correct_rate))
    return correct_rate

def main():
    """Grid-search c and g for the ten one-vs-rest problems.

    Loads ./ocr_1 .. ./ocr_10, applies one shared random permutation to all
    ten data sets, evaluates every (c, g) pair with test() and prints the
    best pair found.
    """
    yarr = []
    xarr = []
    for digit in range(1, 11):
        labels, samples = svm_read_problem('./ocr_' + str(digit))
        yarr.append(labels)
        xarr.append(samples)

    # shuffle: one permutation shared by all data sets, so that a given
    # position refers to the same sample in each of them
    count = len(yarr[0])
    arr = list(range(0, count))
    random.shuffle(arr)
    print("RANDOM ARR: %s" % (arr,))

    for idx in range(0, 10):
        old_y = yarr[idx]
        old_x = xarr[idx]
        yarr[idx] = [old_y[pos] for pos in arr]
        xarr[idx] = [old_x[pos] for pos in arr]

    # grid search: c = 1..15, g = 0/256 .. 255/256
    maxcorrect = -1
    cpos = 0
    gpos = 0
    for c in range(1, 16, 1):
        for gg in range(0, 256, 1):
            g = gg * 1.0 / 256
            ret = test(yarr, xarr, c, g)
            if ret > maxcorrect:
                maxcorrect, cpos, gpos = ret, c, g
            # progress report: best setting seen so far
            print("current c= %s g= %s maxcorrect= %s"
                  % (cpos, gpos, maxcorrect))

    print("c= %s g= %s maxcorrect= %s" % (cpos, gpos, maxcorrect))
    # To evaluate a single setting instead of the whole grid:
    # test(yarr, xarr, 3, 1.0 / 64)


if __name__ == '__main__':
    main()


使用最优参数,估计出来的识别率在85%左右(参数调整影响只有几个点),和样本有关。如果写字比较规范,识别率应该在95%以上,可以想见用印刷体,识别率会有多高。如果歪着写字,或者大小比率比较奇怪,误识别率还是蛮高的。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: