您的位置:首页 > 其它

Spark:Scala实现KMeans算法

2015-03-25 09:15 471 查看
1 什么是KMeans算法

K-Means算法是一种cluster analysis的算法,其主要是来计算数据聚集的算法,主要通过不断地取离种子点最近均值的算法。

具体来说,通过输入聚类个数k,以及包含 n个数据对象的数据库,输出满足方差最小标准的k个聚类。

2 k-means 算法基本步骤

(1) 从 n个数据对象任意选择 k 个对象作为初始聚类中心;

(2) 根据每个聚类对象的均值(中心对象),计算每个对象与这些中心对象的距离;并根据最小距离重新对相应对象进行划分;

(3) 重新计算每个(有变化)聚类的均值(中心对象);

(4) 计算标准测度函数,当满足一定条件,如函数收敛时,则算法终止;如果条件不满足则回到步骤(2)。

算法的时间复杂度上界为O(n*k*t), 其中t是迭代次数,n个数据对象划分为 k个聚类。

3 MLlib实现KMeans

以MLlib实现KMeans算法,利用MLlib KMeans训练出来的模型,可以对新的数据作出分类预测,具体见代码和输出结果。

Scala代码:

1 package com.hq

2

3 import org.apache.spark.mllib.clustering.KMeans

4 import org.apache.spark.mllib.linalg.Vectors

5 import org.apache.spark.{SparkContext,
SparkConf}

6

7 object KMeansTest {

8 def main(args: Array[String]) {

9 if (args.length < 1) {

10 System.err.println("Usage: <file>")

11 System.exit(1)

12 }

13

14 val conf = new SparkConf()

15 val sc = new SparkContext(conf)

16 val data = sc.textFile(args(0))

17 val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))

18 val numClusters = 2

19 val numIterations = 20

20 val clusters = KMeans.train(parsedData,numClusters,numIterations)

21

22 println("------Predict the existing line in
the analyzed data file: "+args(0))

23 println("Vector 1.0 2.1 3.8 belongs to clustering "+ clusters.predict(Vectors.dense("1.0 2.1 3.8".split(' ').map(_.toDouble))))

24 println("Vector 5.6 7.6 8.9 belongs to clustering "+ clusters.predict(Vectors.dense("5.6 7.6 8.9".split(' ').map(_.toDouble))))

25 println("Vector 3.2 3.3 6.6 belongs to clustering "+ clusters.predict(Vectors.dense("3.2 3.3 6.6".split(' ').map(_.toDouble))))

26 println("Vector 8.1 9.2 9.3 belongs to clustering "+ clusters.predict(Vectors.dense("8.1 9.2 9.3".split(' ').map(_.toDouble))))

27 println("Vector 6.2 6.5 7.3 belongs to clustering "+ clusters.predict(Vectors.dense("6.2 6.5 7.3".split(' ').map(_.toDouble))))

28

29 println("-------Predict the non-existent line in
the analyzed data file: ----------------")

30 println("Vector 1.1 2.2 3.9 belongs to clustering "+ clusters.predict(Vectors.dense("1.1 2.2 3.9".split(' ').map(_.toDouble))))

31 println("Vector 5.5 7.5 8.8 belongs to clustering "+ clusters.predict(Vectors.dense("5.5 7.5 8.8".split(' ').map(_.toDouble))))

32

33 println("-------Evaluate clustering by computing Within Set Sum of Squared Errors:-----")

34 val wssse = clusters.computeCost(parsedData)

35 println("Within Set Sum of Squared Errors = "+ wssse)

36 sc.stop()

37 }

38 }

复制代码

4 以Spark集群standalone方式运行

①在IDEA打成jar包(如果忘记了,参见Spark:用Scala和Java实现WordCount),上传到用户目录下/home/ebupt/test/kmeans.jar

②准备训练样本数据:hdfs://eb170:8020/user/ebupt/kmeansData,内容如下

[ebupt@eb170 ~]$ hadoop fs -cat ./kmeansData

1.0 2.1 3.8

5.6 7.6 8.9

3.2 3.3 6.6

8.1 9.2 9.3

6.2 6.5 7.3

复制代码

③spark-submit提交运行

[ebupt@eb174 test]$ spark-submit
--master spark://eb174:7077 --name KmeansWithMLib --class com.hq.KMeansTest --executor-memory 2G --total-executor-cores 4 ~/test/kmeans.jar hdfs://eb170:8020/user/ebupt/kmeansData

输出结果摘要:

1 ------Predict the existing line in the analyzed data file: hdfs://eb170:8020/user/ebupt/kmeansData

2 Vector 1.0 2.1 3.8 belongs to clustering 0

3 Vector 5.6 7.6 8.9 belongs to clustering 1

4 Vector 3.2 3.3 6.6 belongs to clustering 0

5 Vector 8.1 9.2 9.3 belongs to clustering 1

6 Vector 6.2 6.5 7.3 belongs to clustering 1

7 -------Predict the non-existent line in
the analyzed data file: ----------------

8 Vector 1.1 2.2 3.9 belongs to clustering 0

9 Vector 5.5 7.5 8.8 belongs to clustering 1

10 -------Evaluate clustering by computing Within Set Sum of Squared Errors:-----

11 Within Set Sum of Squared Errors = 16.393333333333388

复制代码

5 Spark总结

本文主要介绍了MLbase如何实现机器学习算法,简单介绍了MLBase的设计思想。

与其它机器学习系统Weka、mahout不同:

MLbase是分布式的,Weka是单机的。

Mlbase是自动化的,Weka和mahout都需要使用者具备机器学习技能,来选择自己想要的算法和参数来做处理。

MLbase提供了不同抽象程度的接口,可以扩充ML算法。

参考文献:http://www.aboutyun.com/thread-10817-1-1.html
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: