机器学习实战K-近邻算法
2017-09-24 19:53
169 查看
今天开始学习机器学习,第一章是K-近邻算法,有不对的地方请指正
大概总结一下近邻算法写分类器步骤:
1. 计算测试数据与已知数据的特征值的距离,离得越近越相似
2. 取距离最近的K个已知数据的所属分类
3. 最后统计K个值的分类分别出现的概率,返回最多的一个属性,即为测试数据的所属分类
4. 至于怎么把文本转换成numpy的类型,需要学习numpy模块的相关知识,附上
numpy学习连接 http://old.sebug.net/paper/books/scipydoc/numpy_intro.html
大概总结一下近邻算法写分类器步骤:
1. 计算测试数据与已知数据的特征值的距离,离得越近越相似
2. 取距离最近的K个已知数据的所属分类
3. 最后统计K个值的分类分别出现的概率,返回最多的一个属性,即为测试数据的所属分类
4. 至于怎么把文本转换成numpy的类型,需要学习numpy模块的相关知识,附上
numpy学习连接 http://old.sebug.net/paper/books/scipydoc/numpy_intro.html
#-*- coding:utf-8 *-*- from numpy import * import operator #计算模块 import matplotlib import matplotlib.pyplot as plt import time import random from mpl_toolkits.mplot3d import Axes3D from os import listdir import time def createDataSet(): group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) labels = ['A','A','B','B'] return group,labels #A,B分类 def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] diffMat = tile(inX,(dataSetSize,1)) - dataSet #tile函数把inx复制datasetsize行1列 sqDiffMat = diffMat**2 #print "sqDiffMat : ",sqDiffMat sqDistance = sqDiffMat.sum(axis = 1) distance = sqDistance**0.5 #print "distance : ",distance sortedDistIndicies = distance.argsort() #返回从小到大的元素的下标,比如[1 3 2 4].argsort()返回[0 2 1 3] #print "****",sortedDistIndicies classCount = {} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] #统计各个现有值所属的特征向量 #print sortedDistIndicies[i],voteIlabel classCount[voteIlabel] = classCount.get(voteIlabel,0)+1 #统计各个特征向量出现的次数 sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True) #operator.itemgetter()从小到大排序 #print "sortedClassCount : ",sortedClassCount return sortedClassCount[0][0] group,labels = createDataSet() #print classify0([0,0], group, labels, 3) # # a = [('b',2),('a',1),('c',0)] # a=[('b',2),('a',2),('a',1),('c',0)] # b = sorted(a,key = operator.itemgetter(0)) #优先根据第一个元素排序 # print b # b = sorted(a,key = operator.itemgetter(1)) #优先根据第二个元素排序 # print b # b = sorted(a,key = operator.itemgetter(1,0)) #优先根据第二个元素排序,当第二个元素相等的情况下根据第一个元素排序 # print b #解析数据 def file2matrix(filename): with open(filename) as f: lines = f.readlines() matrixNumber = len(lines) print 'the all lines is :',matrixNumber #matrix = zeros((matrixNumber,3),dtype = 'int') #生成空的n行3列的矩阵 matrix = zeros((matrixNumber,2)) vector = [] index = 0 #矩阵索引 for line in lines: line = line.strip() data = line.split("\t") matrix[index:] = data[0:2] #把提取出来的复制到矩阵里面 vector.append(int((data[-1]))) #最后一个特征值作为特征向量 index+=1 return matrix,vector #生成文本数据 def createdata(filename): with open(filename,'w') as f: for i in range(1000): r1 = int(random.random()*1000) r2 = 0 if(0<=r1<=200): r2 = 1 if(200<r1<=400): r2 = 2 if(400<r1<=600): r2 = 3 if(600<r1<=800): r2 = 4 if(800<r1<=1000): r2 = 5 r1 = str(r1) r2 = str(r2) #r2 = str(int(random.random()*10)) r3 = str(int(random.random()*10)) f.writelines(r3+'\t'+r1+'\t'+r2+'\n') #createdata(r'D:\test_packages\knntest.txt') ''' datat,labels = file2matrix(r'D:\test_packages\knntest.txt') print datat # print datat[:,1] #纵向的第二列 # print datat[:][1] #横向的第二列 print labels fig = plt.figure() #生成容器 plt.title('favorite table data') ax = fig.add_subplot(1,1,1,projection='3d') #3D模型 ax.scatter(datat[:,0],datat[:,1],datat[:,2],array(labels),array(labels),array(labels)) #使用datat的第二列和第三列作为X轴和Y轴的值 ax.legend() plt.show() fig = plt.figure() ax = fig.add_subplot(1,1,1) #把容器划分为1行1列,图像画在第一格,背景颜色为axisbg = ‘’ ax.scatter(datat[:,1],datat[:,2],array(labels),array(labels)) #使用datat的第二列和第三列作为X轴和Y轴的值 #ax.grid(True) #是否显示网格 # plt.show() plt.show() ''' #归一化,(old-min)/(max-min) def autoNormal(dataSet): maxVals = dataSet.max(0) #纵向找到每一个样本的最大特征值 minVals = dataSet.min(0) ranges = maxVals - minVals #计算差值 normalValue = zeros(shape(dataSet)) m = dataSet.shape[0] normalValue = dataSet - tile(minVals,(m,1)) #计算(old-min) normalValue = normalValue/tile(ranges,(m,1)) return normalValue,ranges,minVals #归一化特征值之后 datat,labels = file2matrix(r'D:\test_packages\knntest.txt') normalValue,ranges,minVals = autoNormal(datat) print normalValue fig = plt.figure() ax = fig.add_subplot(1,1,1) #把容器划分为1行1列,图像画在第一格,背景颜色为axisbg = ‘’ ax.scatter(normalValue[:,0],normalValue[:,1],array(labels),array(labels)) #使用datat的第二列和第三列作为X轴和Y轴的值 #ax.grid(True) #是否显示网格 # plt.show() plt.show() #约会网站测试函数 def datinggTest(): datat,labels = file2matrix(r'D:\test_packages\knntest.txt') normal,ranges,minvals = autoNormal(datat) testData = 0.5 #10%用来测试,90%用来训练 testNumber = normal.shape[0] #总行数 numberTestValues = int(testNumber*testData) #测试行数 error = 0.0 for i in range(numberTestValues): labelValue = classify0(normal[i,:], normal[numberTestValues:testNumber,:], labels[numberTestValues:testNumber], 3) if (labelValue != labels[i]): error+=1.0 print "this time is error the error is %s, the right is %s"%(labelValue,labels[i]) else: print "all right ,the number is %s, the right is %s"%(labelValue,labels[i]) error_result = ((error/float(numberTestValues))) print "your error_result is %s"%(error_result) print 'error is :',error datinggTest() #把二进制文件转化为np.array def img2Vector(filename): with open(filename) as f: vector = zeros((1,1024)) for i in range(32): line = f.readline() for j in range(32): vector[0,32*i+j] = line[j] return vector vector = img2Vector(r'D:\test_packages\trainingDigits\0_0.txt') print vector[0,11:17] #手写数字识别系统测试代码 def handwritingClassTest(): startTime = time.ctime() handLabels = [] trainFile = listdir(r'D:\test_packages\trainingDigits') m = len(trainFile) trainMat = zeros((m,1024)) for i in range(m): fileName = trainFile[i] file = fileName.split('.')[0] classNumber = file.split('_')[0] handLabels.append(classNumber) trainMat[i,:] = img2Vector(r'D:\test_packages\trainingDigits\%s'%fileName) testFiles = listdir(r'D:\test_packages\testDigits') nTest = len(testFiles) error = 0.0 for i in range(nTest): fileName = testFiles[i] file = fileName.split('.')[0] classNumber = file.split('_')[0] testMat = img2Vector(r'D:\test_packages\testDigits\%s'%fileName) testLabels = classify0(testMat, trainMat, handLabels, 3) if (testLabels != classNumber): error+=1.0 print 'error , error number is %s, the right number is %s'%(testLabels,classNumber) else: print 'right' error = error/float(nTest) stopTime = time.ctime() print 'all right ,the error_result is %s'%(error) print 'the process start at %s'%(startTime) print 'the process stop at %s'%(stopTime) handwritingClassTest()
相关文章推荐
- 机器学习实战K-近邻算法遇到的几个错误
- 机器学习实战k-近邻算法(kNN)应用之改进婚恋网站配对效果代码解
- machine_learning-knn算法详解(近邻算法)
- 机器学习之二:K-近邻算法(KNN)
- 机器学习实战之k-近邻算法(5)--- 完整版约会网站数据分类
- 机器学习系列(二)k-近邻算法(1)
- K-近邻算法学习心得体会
- 《机器学习实战》----k-近邻算法
- k-近邻算法
- 机器学习笔记(7)---K-近邻算法(5)---使用K近邻算法检测异常操作之二
- K-近邻算法简介
- K-近邻算法
- 机器学习实战学习笔记(一):K-近邻算法
- 机器学习:kNN近邻算法
- K-近邻算法(KNN)
- k-近邻算法
- 机器学习实战 k-近邻算法(kNN)
- 《Python机器学习实战》第一章读书笔记:k-近邻算法
- 机器学习实战(k-近邻算法)
- K近邻改进约会网站(五):使用算法进行预测