您的位置:首页 > 其它

转: Kaggle入门模板:以手写识别Digit Recognizer为例

2017-12-23 16:31 411 查看
首先本文参考了点击打开链接 这篇博客,然后可能时间有点久远,Kaggle的这道题给的数据文档和之前的不一样了,以及还有一些注意点这篇文章里没有突出。因此这里重新做个总结,希望大家能早点入个门。

这里我使用的sklearn中的支持向量机来解决手写识别问题。这里的svm是可以解决多分类问题的。核函数使用的是高斯核(rbf),松弛变量c选择的是5.

kaggle这道题一共提供了3个文件:train.csv,test.csv,sample_submission.csv 。 分别表示训练集,测试集,提交样例。

下面上python代码。本人的macbook pro16,运行时间为575秒。svm的准确率在这个问题上可能不及knn,但是运行的效率要比knn高了许多。。。

[python] view
plain copy

#!/usr/bin/python    

# -*- coding: utf-8 -*-    

    

from numpy import *    

from sklearn import svm      

import csv     

import datetime  

  

#把数组中的字符串转换成整数  

def toInt(array):   

    array=mat(array)    

    m,n=shape(array)    

    #使用xrange不会生成list,性能要优于range  

    for i in xrange(m):    

        for j in xrange(n):    

                array[i,j]=int(array[i,j])    

    return array    

  

#把大于0的数都置为1  

def nomalizing(array):    

    m,n=shape(array)    

    for i in xrange(m):    

        for j in xrange(n):    

            if array[i,j]!="0":  #注意原csv文件中的数字也是字符串  

                array[i,j]=1    

            else:  

                array[i,j]=0  

    return array    

  

def loadTrainData():    

    l=[]    

    with open('train.csv') as file:    

         lines=csv.reader(file)    

         for line in lines:    

             l.append(line) #42001*785    

    l.remove(l[0])  #移除第0行,第0行是数据列名  

    l=array(l)  #将l由list型转化为numpy下的array型  

    label=l[:,0]  #label赋值为l的第0列  

    data=l[:,1:]  #data赋值为l的第1至最后一列  

    return nomalizing(data),toInt(label)   

  

def loadTestData():    

    l=[]    

    with open('test.csv') as file:    

         lines=csv.reader(file)   

         for line in lines:    

             l.append(line)    

    l.remove(l[0])    

    data=array(l)    

    return nomalizing(data)    

  

def saveResult(result,csvName):    

    with open(csvName,'wb') as myFile:        

        myWriter=csv.writer(myFile)   

        num = 1   

        arr=[]  

        arr.append("ImageId")  

        arr.append("Label")  

        myWriter.writerow(arr)  #先将列名插入第0行  

        for i in result:    

            tmp=[]   

            tmp.append(num)  

            num = num + 1   

            tmp.append(int(i))  ##不能是浮点数    

            myWriter.writerow(tmp)    

      

def svcClassify(trainData,trainLabel,testData):     

    svcClf=svm.SVC(C=5.0) #default:C=1.0,kernel = 'rbf'. you can try kernel:‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’      

    svcClf.fit(trainData,ravel(trainLabel))    

    testLabel=svcClf.predict(testData)    

    saveResult(testLabel,'sklearn_SVC_C=5.0_Result.csv')    

    return testLabel    

  

def main():    

    starttime = datetime.datetime.now()  

    trainData,trainLabel=loadTrainData()    

    print "训练集读取完毕"  

    testData=loadTestData()     

    print "测试集读取完毕"  

    svcClassify(trainData,trainLabel,testData)  

    endtime = datetime.datetime.now()  

    print "预测结束--程序总运行时间:"+str((endtime - starttime).seconds)+"秒"  

  

main() #主函数  

ps:本人一开始在kaggle上提交结果,总是返回的准确率为0.00000,后来用文本编辑器打开了csv,才发现自己生成的label都是浮点数,而在excel中看不出来,坑。

kaggle提交注意事项:

每道题目一天最多交5次,大家珍惜每天的提交机会
提交的csv要严格遵循sample_submission.csv中的格式,也就是在提交文件中第一行的列名也是需要加的,且列名不能出错。
提交的数据一定要弄清是整数还是浮点数。否则提交后是会被判断为预测错误的。
       kaggle这个平台真心不错,让我找回了codeforces的感觉,感觉找到了一个很好的锻炼动手能力的平台,希望大家能经常做做练习~
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: