您的位置:首页 > 其它

机器学习(5)--K-means聚类(Clustering)算法

2018-02-01 22:07 477 查看
 K-means算法简述:

 1、 K-means算法是聚类(Clustering)中的经典算法,同时,也是数据挖掘的经典算法之一

 2、 该算法主要参数K,即在一些样本数据数,我们不知道每个样本是什么类,但是我们知道全部的样本分为几类或是我们想把样本分为几类,这里的几类就是K

 3、本例基本步骤

    3.1 选取前K个样本,每个样本分为一类,并设置这个K样本的坐标为中心点

    3.2 计算所有每个样本与中心点的距离,这样得到每个中心点有哪些样本

    3.3 计算每个中心点所有样本的坐标的平均值,做为新的坐标

    3.4 循环3.2步聚,直至后一次的循环每个中心点所包含的样本不再发生变化时,退出循环

本文将通过matplotlib来显示中心点的变化与每个类的变化,如果你未安装matplotlib,可以屏蔽这几句相关的内容
程序运行时会跳出matplotlib窗体,并中断程序,关闭窗体后程序会继续执行。

# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np

#下面两行,解决matplotlib中无法显示中文的问题
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']

#数据定义,每一行为一个坐标,你也可以自行修改
data='''
1,1
2,1
4,5
6,6
5,4
3,3
2,2
3,2
5,6
1,3
3,1
6,5
''';

#整理数据
#[0,0]的意思在每个坐标后加两个值,每个样本数据均为4个值,[x,y,0,0]
#第1个为分类,
#第2个作用循环进行分类时为与前次分类比较,看是否发生变化,变化则为1,当所有样本均未变化时结束分类
data=[x.split(',')+[0,0] for x in data.split('\n')]
data=list(filter(lambda x: len(x)==4,data))
data=np.array(data).astype(np.float)
#print(data)
#取得中心点,选取前K个样本,每个样本分为一类

k=2  #在本中因为在matplotlib按类显示不同的点,只设了四个显示,分类数别太多了
centroids=data.copy()[:k,:-2]#-2表示只取样本的坐标

def draw(centroids,data,title):
plt.axis([round((np.min(data,axis=0)-1)[0])
,round((np.max(data,axis=0)+1)[0])
,round((np.min(data,axis=0)-1)[1])
,round((np.max(data,axis=0)+1)[1])]) # 用于定义X,Y轴的范围
plt.title(title)
for index,center in enumerate(centroids):
colorStr='rgby'[index:index+1] #在本中因为在matplotlib按类显示不同的点,只设了四个显示,分类数别太多了
centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))
if len(centerData)>0 :plt.scatter(centerData[:,0],centerData[:,1],c=colorStr)
plt.scatter(center[0],center[1],c=colorStr,marker='x')
plt.show()

runtimes=0
changePointLength=-1
while changePointLength!=0:
runtimes+=1
draw(centroids,data,'第 %d 次'%runtimes + ('  首次仅显示中心点,因为所有点的未分类' if runtimes==1 else ''))
#3.2 计算所有每个样本与中心点的距离,这样得到每个中心点有哪些样本
for dataItem in data:
distances=np.sqrt(((centroids-dataItem[:-2])**2).sum(axis=1))#计算每个点与每个中心点的距离
minDisType=np.argmin(distances)+1 #取得取小的距离的分类号
if dataItem[-2]==minDisType:
dataItem[-1]=0 #如果分类结果未发生变化
else :
dataItem[-1]=1 #如果分类结果发生变化
dataItem[-2]=minDisType

print(data)

#3.3 计算每个中心点所有样本的坐标的平均值,做为新的坐标
for index,center in enumerate(centroids):
centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))[:,:-2] #得到每中心点包含哪些点
centerData=centerData.mean(axis=0)
center[0]=centerData[0]
center[1]=centerData[1]
#print(data)
changePointLength=len(list(filter(lambda x:x[-1]==1 ,data))) #看有几个点的分类发生变化,如果为零,则退出循环
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息