您的位置:首页 > 其它

KNN分类器及实现

2015-06-01 21:47 218 查看

出处:/article/7618183.html
KNN即K-Nearest Neighbor,是数据挖掘中一种最简单的分类方法,即要判断某一个样本属于已知样本种类中的哪一类时,通过计算找出所有样本中与测试样本最近或者最相似的K个样本,统计这K个样本中哪一种类最多则把测试样本归位该类。如何衡量两个样本的相似度?可以用向量的p-范数来定义。
假设有两个样本X=(x1, x2, ..., xn),Y=(y1, y2, ..., yn),则他们之间的相似度可以用以下向量p-范数定义:



当p=2时即为计算X、Y的欧几里得距离。
本文将介绍用Java实现KNN分类器对Iris数据进行分类。Iris数据如下:





前面四个item是属性,最后一个是类别名,总共有三类。完整的数据集可点击这里下载。
拿到原始数据后为了测试KNN分类效果,需要在原始数据中随机抽取一部分作为测试集,另一部分作为训练集。随机抽取的方法可以用下面代码实现:

[java] view
plaincopy

/**

* 将数据集划分为训练集和测试集,随机划分

*

* @param filePath

* 数据集文件路径

* @param testCount

* 测试集个数

* @param outputPath

* 输出路径

* @throws Exception

*/

public static void splitDataSet(String filePath, int testCount,

String outputPath) throws Exception

{

BufferedWriter trainFile = new BufferedWriter(new FileWriter(new File(

outputPath + "/train.txt")));

BufferedWriter testFile = new BufferedWriter(new FileWriter(new File(

outputPath + "/test.txt")));

BufferedReader input = new BufferedReader(new FileReader(new File(

filePath)));

List<String> lines = new ArrayList<String>();

String line = null;

//将所有数据读取到一个List里

while ((line = input.readLine()) != null)

lines.add(line);

//遍历一次List,每次产生一个随机序号,将该随机序号和当前序号内容进行交换

for (int i = 0; i < lines.size(); i++)

{

int ran = (int) (Math.random() * lines.size());

String temp = lines.get(i);

lines.set(i, lines.get(ran));

lines.set(ran, temp);

}

int i = 0;

//将指定数目的测试集写进test.txt中

for (; i < testCount; i++)

{

testFile.write(lines.get(i) + "\n");

testFile.flush();

}

//剩余的写进train.txt中

for (; i < lines.size(); i++)

{

trainFile.write(lines.get(i) + "\n");

trainFile.flush();

}

testFile.close();

trainFile.close();

}

调用这个方法后就可以得到train.txt和test.txt两份数据了。

接下来将数据读入:

[java] view
plaincopy

/**

* 根据文件生成训练集,注意:程序将以第一个出现的非数字的属性作为类别名称

*

* @param fileName

* 文件名

* @param sep

* 分隔符

* @return

* @throws Exception

*/

public List<DataNode> getDataList(String fileName, String sep)

throws Exception

{

List<DataNode> list = new ArrayList<DataNode>();

BufferedReader br = new BufferedReader(new FileReader(

new File(fileName)));

String line = null;

while ((line = br.readLine()) != null)

{

String splits[] = line.split(sep);

//DataNode类用于保存数据属性和数据类别

DataNode node = new DataNode();

int i = 0;

for (; i < splits.length; i++)

{

try

{

node.addAttrib(Float.valueOf(splits[i]));

} catch (NumberFormatException e)

{

// 非数字,则为类别名称,将类别映射为数字

if (!mTypes.containsKey(splits[i]))

{

mTypes.put(splits[i], mTypeCount);

mTypeCount++;

}

node.setType(mTypes.get(splits[i]));

list.add(node);

}

}

}

return list;

}

对于testList中的每一个样本,均与所有trainList中的样本进行计算,取出最接近的K个样本并返回:

KnnClassifier.java

[java] view
plaincopy

package com.jingchen.knn;

import java.util.List;

/**

* @author chenjing

*

*/

public class KnnClassifier

{

//k个近邻节点

private int k;

private KNode[] mNearestK;

private List<DataNode> mTrainData;

public KnnClassifier(int k, List<DataNode> trainList)

{

mTrainData = trainList;

this.k = k;

mNearestK = new KNode[k];

for (int i = 0; i < k; i++)

mNearestK[i] = new KNode();

}

public void setK(int k){

this.k = k;

mNearestK = new KNode[k];

for (int i = 0; i < k; i++)

mNearestK[i] = new KNode();

}

private void train(DataNode test, float p)

{

for (int i = 0; i < mTrainData.size(); i++)

{

putNode(getSim(test, mTrainData.get(i), p));

}

}

/**

* 将新计算出来的节点与k个近邻节点比较,如果比其中之一小则插入

* @param node

*/

private void putNode(KNode node)

{

for (int i = 0; i < k; i++)

{

if (node.getD() < mNearestK[i].getD())

{

for (int j = k - 1; j > i; j--)

mNearestK[j] = mNearestK[j - 1];

mNearestK[i] = node;

break;

}

}

}

/**

* 获取相似度并封装成一个KNode类型返回

* @param test

* @param trainNode

* @param p

* @return

*/

private KNode getSim(DataNode test, DataNode trainNode, float p)

{

List<Float> list1 = test.getAttribs();

List<Float> list2 = trainNode.getAttribs();

float d = 0;

for (int i = 0; i < list1.size(); i++)

d += Math.pow(

Math.abs(list1.get(i).floatValue() - list2.get(i).floatValue()), p);

d = (float) Math.pow(d, 1/p);

KNode node = new KNode(d, trainNode.getType());

return node;

}

private void reset()

{

for (int i = 0; i < k; i++)

mNearestK[i].reset();

}

/**

* 返回K个近邻节点

* @param test

* @param p

* @return

*/

public KNode[] getKNN(DataNode test, float p)

{

reset();

train(test, p);

return mNearestK;

}

}

main方法:

[java] view
plaincopy

public static void main(String[] args) throws Exception

{

DataUtil util = DataUtil.getInstance();

//获得训练集和测试集

List<DataNode> trainList = util.getDataList("E:/train.txt", ",");

List<DataNode> testList = util.getDataList("E:/test.txt", ",");

int K = BASE_K;

KnnClassifier classifier = new KnnClassifier(K, trainList);

BufferedWriter output = new BufferedWriter(new FileWriter(new File(

"E:/output.txt")));

int typeCount = util.getTypeCount();

int[] count = new int[typeCount];

for (int i = 0; i < testList.size();)

{

for (int m = 0; m < typeCount; m++)

count[m] = 0;

DataNode test = testList.get(i);

classifier.setK(K);

KNode[] nodes = classifier.getKNN(test, 2);

for (int j = 0; j < nodes.length; j++)

count[nodes[j].getType()]++;

int type = -1;

int max = -1;

for (int j = 0; j < typeCount; j++)

{

if (count[j] > max)

{

max = count[j];

type = j;

} else if (count[j] == max)

{

// 存在两个类型分个数相同,无法判断属于哪个类型,增加K的值继续从该节点开始

type = -1;

K++;

break;

}

}

if (type == -1)

continue;

else

{

i++;

K = BASE_K;

}

//将分类结果写入文件

List<Float> attribs = test.getAttribs();

for (int n = 0; n < attribs.size(); n++)

{

output.write(attribs.get(n) + ",");

output.flush();

}

output.write(util.getTypeName(type) + "\n");

output.flush();

}

output.close();

}

经测试,KNN对Iris数据集分类准确率基本都在90+%以上,此分类方法也比较直观。数据集及完整的项目代码可以从这里下载:点击下载
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: