K近临算法(KNN)
2016-03-03 17:15
417 查看
什么是K近临算法
K近临算法是基于实例的学习算法,俗称KNN什么是基于实例
基于实例的学习算法只是简单的把样例存储起来。把这些实例中泛化的工作推迟到必须分类的时候。每当学习器遇到一个新的实例时,它将实时分析这个实例与以前存储的实例的关系,并据此把一个目标函数值赋到新的实例。K近临原理介绍
k近临的思路是找K个与目标最相似的样本,认为目标就属于K个样本中最多频次的类别。分类
对未知类别属性数据中每个点的分类过程:
1. 计算未知点到训练数据点的距离
2. 对这些训练数据点递增排序
3. 选K个最近的点
4. 确定K个点所在类别的频率
5. 返回K个点出现频率最高的类别作为未知点类别
其中的距离计算一般为欧式距离,或者城市距离(具体依实际情况定)
如上图所示,如果K=3就是实线所圈的点,那么未知点就会被分成红色三角形的类别上。如果K=5那就是虚线所圈的点,未知点就会被分成蓝色四方块的类别上。
回归
1. 计算未知点到训练数据点的距离
2. 对这些训练数据点递增排序
3. 选K个最近的点
4. 将K个最近点的属性均值赋值给目标点(得到目标点的属性)
当然取均值只是一种方法,还可以按样本点与目标点的距离进行加权等方法进行属性值赋值。
K近临的优缺点
优点
抗噪声能力强:K近临只选最近的K个点,因些一些离群点完全不影响算法的处理效果。模型简单,分类效果相当不错。kaggle的手势识别,朴素的算法就能达到96.5%的准确率
由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方 法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
缺点
维度灾难:k-近邻算法在计算距离的时候,考虑实例的所有属性。可能分类仅由某几个属性决定,这中情况下属性的相似性度量会误导k-近邻算法的分类。解决办法:(1)属性加权;(2)剔除不相关的属性。
效率超慢:对于每个目标点都需要计算它到所有样本点的距离
当样本维度为NN特征维度为DD,那么复杂度就是呈O(ND)O_\left(ND\right)
1) Ko与Seo提出一算法TCFP(text categorization using feature projection),尝试利用特征投影法(en:feature projection)来降低与分类无关的特征对于系统的影响,并借此提升系统效能,其实实验结果显示其分类效果与k最近邻居法相近,但其运算所需时间仅需k最近邻居法运算时间的五十分之一。
2)所有通常把KNN的数据集训练成KD树,构建过程很快,甚至不用计算D维欧氏距离,而搜索速度高达O(Dlog(N))O_\left(Dlog(N)\right)。不过当特征维度过高时会产生维度灾难,导致KD树的效率接近O(ND)O_\left(ND\right).按经验是D>20时,最好改用更高效的ball-tree效率O(Dlog(N))O_\left(Dlog(N)\right).
这个我用kaggle的手势识别数据测过,42000个训练数据,782个特征,28000测试数据,朴素的KNN要跑8个小时左右,ball_tree, 几分钟或几十分钟内
当样本不均匀时,那么有可能K个样本中某一类的样本都会占主导,从而影响分类效果,增加K影响就可能越明显。这个可以考虑使用距离加权来解决,加大近距离点的影响力。具体方法很多,需要考虑实际效果。如按距离的+1对数倒数
K的选择
一般按经验k=5.如果选择较小的K值,就相当于用较小的领域中的训练实例进行预测,“学习”近似误差会减小,只有与输入实例较近或相似的训练实例才会对预测结果起作用,与此同时带来的问题是“学习”的估计误差会增大,换句话说,K值的减小就意味着整体模型变得复杂,容易发生过拟合;
如果选择较大的K值,就相当于用较大领域中的训练实例进行预测,其优点是可以减少学习的估计误差,但缺点是学习的近似误差会增大。这时候,与输入实例较远(不相似的)训练实例也会对预测器作用,使预测发生错误,且K值的增大就意味着整体的模型变得简单。
K=N,则完全不足取,因为此时无论输入实例是什么,都只是简单的预测它属于在训练实例中最多的累,模型过于简单,忽略了训练实例中大量有用信息。
但我有个想法,就是可以用优化算法来确定这个K,也就假设这个K和KNN的模型效果呈线性关系。这样就可以选用传统的优化算法来计算选择一个合适的K了。
优化算法有:
1. 遗传算法
2. 模拟退火算法
我建议使用遗传算法。评价就用KNN算法的准确率,采用交叉验证法(就是一部分样本做训练集,一部分做测试集)来计算准确率。
sklearn代码
简单贴个kaggle的数字识别的代码吧,还有其它代码请自动忽略吧。#!/usr/bin/env python # coding=utf-8 ################################################################# # File: RandomForestDR.py # Author: Neal Gavin # Email: nealgavin@126.com # Created Time: 2016/03/02 15:49:29 # Saying: Fight for freedom ^-^ ! ################################################################# import numpy as np import time from sklearn.ensemble import RandomForestClassifier from sklearn.cross_validation import train_test_split from sklearn.naive_bayes import GaussianNB from sklearn.linear_model import SGDClassifier from sklearn.svm import SVC from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import precision_recall_curve from sklearn.neighbors import KNeighborsClassifier class DigitRC(object): """DigitRC""" def readData(self): """诗入训练数据""" self.time_start = time.time() trainData = np.loadtxt('./train.csv', delimiter = ',', skiprows = 1) testData = np.loadtxt('./test.csv', delimiter = ',', skiprows = 1) print type(trainData) X_train = np.array([x[1:] for x in trainData]) Y_train = np.array([x[0] for x in trainData]) self.printTime() print 'split data' self.X_train, self.X_valid, self.Y_train, self.Y_valid = train_test_split(X_train, Y_train, test_size = 0.1, random_state = 42) self.printTime() print 'train', np.shape(self.X_train), np.shape(self.Y_train) print 'valid', np.shape(self.X_valid), np.shape(self.Y_valid) self.X_test = testData def printTime(self): """time""" time_spend = time.time() - self.time_start print 'time:', time.strftime('%H:%M:%S', time.gmtime(time_spend)) def Train(self, method = 'RandomForestClassifier'): """训练""" if method == 'RandomForestClassifier': #准确率 算法 #0.964 随机森林 rf = RandomForestClassifier(n_estimators = 100) elif method == 'NaiveBayes': #0.567 朴素贝叶斯 rf = GaussianNB() elif method == 'SGDClassifier': #0.868 SVM rf = SGDClassifier(loss = 'hinge', penalty = 'l2', alpha = 1e-3, n_iter = 5) elif method == 'SVC': #0.11 速度巨慢 rf = SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, decision_function_shape=None, degree=3, gamma='auto', kernel='linear', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) elif method == 'DecisionTreeClassifier': #0.867 决策树 rf = DecisionTreeClassifier(criterion = 'entropy') elif method == 'KNeighborsClassifier': #0.965 KNN 4分钟 4000+ rf = KNeighborsClassifier(algorithm = 'auto') print 'train_begin:', method self.rf_model = rf.fit(self.X_train, self.Y_train) print 'train_end' self.printTime() print u'准确率评估' score = rf.score(self.X_valid, self.Y_valid) print '准确率:', score self.printTime() def Test(self): """测试""" print 'test start' testResult = self.rf_model.predict(self.X_test) self.testResult = testResult print 'test end' # #准确率召回率 # precision, recall, thresholds = precision_recall_curve(self.Y_train, self.rf_model.predict(self.X_train)) # print 'precision:', precision, 'recall', recall, 'thresholds', thresholds self.printTime() def SaveResult(self): """保存结果""" pred = [[inx+1, x] for inx, x in enumerate(self.testResult)] np.savetxt('./myans.csv', pred, delimiter = ',', fmt = '%d,%d', header='ImageId,Label') self.printTime() print 'Done' def process(self): self.readData() # self.Train('NaiveBayes') # self.Train('SGDClassifier') self.Train('SVC') # self.Train('DecisionTreeClassifier') # self.Train('KNeighborsClassifier') self.Test() self.SaveResult() if __name__ == '__main__': tt = DigitRC() tt.process()
kd树介绍:/article/1350686.html
相关文章推荐
- 量化交易 ,金融策略的基础!
- 【java学习笔记】for增强循环
- HDU 1846 Brave Game 巴士博弈
- iPhone 6 / 6 Plus 设计·适配方案
- 不用Ubuntu,自己动手下载Android源码
- GridView的属性
- 线段树求逆序数方法 HDU1394&&POJ2299
- 正则表达式(包含与不包含)
- mysql合并相同字段,不同的拼接在起后
- webUploader上传组件 实际运用小结
- Tomcat各种内存溢出解决办法总结
- JAVA的StringBuffer类
- linux下Drools6.3.0规则引擎的安装配置
- Java对象克隆——浅克隆和深克隆的区别
- tomcat部署web项目的3中方法
- 初识spring之quartz定时调度
- 用workspace管理工程,并解决多静态库依赖
- java 字符串缓冲池 String缓冲池
- Android PowerImageView实现,可以播放动画的强大ImageView
- 信息检索的评价指标