您的位置:首页 > 其它

KMeans biKMeans

2013-08-17 17:08 190 查看
KMeans 和 biKMeans都容易陷入局部最优,biKMeans的效果也不好 

所以需要多次运行 找SSE最小的那个

# coding=utf-8
from numpy import *

#Kmeans算法:可能收敛到局部最小值,在大规模数据集上收敛的比较慢
#Kmeans是发现给定数据集K个簇的算法,k是由用户自己指定的
#1. 随机选择K个初始值作为质心,然后将数据集中每个点分配到一个簇中,然后每个簇的质心更新为该簇所有点
#的平均值
#加载数据集
def loadDataSet(fileName):
dataMat=[]
fr=open(fileName)
for line in fr.readlines():
curLine=line.strip().split('\t')
fltLine=map(float,curLine)
dataMat.append(fltLine)
return dataMat
#定义距离公式,这里使用欧式距离
def distEclud(vecA,vecB):
return sqrt(sum(power(vecA-vecB,2)))
#随机生成k个质心
def randCent(dataSet,k):
n=shape(dataSet)[1]#列
centroids=mat(zeros((k,n)))#k个质心 每个质心都是1*n维的
for j in range(n):
minJ=min(dataSet[:,j])#第j列的最小值
rangeJ=float(max(dataSet[:,j])-minJ)#第j列的最大值和最小值的差
centroids[:,j]=minJ+rangeJ*random.rand(k,1)
return centroids
#kMeans算法 k 簇的数目 dataSet:原始数据集,distMeas 距离计算函数,createCent k个质心生成函数
def KMeans(dataSet,k,distMeas=distEclud,createCent=randCent):
m=shape(dataSet)[0]#dataSet的样本个数
clusterAssment=mat(zeros((m,2)))#每个样本属于哪个簇,和簇心的欧式距离
#生成K个随机质心
centroids=createCent(dataSet,k)
clusterChanged=True#标识数据集还会不会在改变了
while clusterChanged:
clusterChanged=False
#对于每一个数据集
for i in range(m):
minDist=inf;minIndex=-1
for j in range(k):#计算样本i距离最近的质心 minIndex
distJI=distMeas(centroids[j,:],dataSet[i,:])#计算第j个质心距离第i个样本的欧式距离
if distJI<minDist:#
minDist=distJI#
minIndex=j#记录最近的质心
if clusterAssment[i,0]!=minIndex:#若当前样本距离最近的质心需要更改
clusterChanged=True
clusterAssment[i,:]=minIndex,minDist**2
#当更新完毕之后,重新计算每个簇的中心点
for cent in range(k):
ptsInClust=dataSet[nonzero(clusterAssment[:,0].A==cent)[0]]
#更新质心
centroids[cent,:]=mean(ptsInClust,axis=0)
return centroids,clusterAssment#返回稳定的质心,返回clusterAssment 该集合标识着样本i属于第j个簇,距离该簇的中心距离为clusterAssment[j][1]
#二分kMeans
#dataSet:初始数据集 K 质心个数 distMeas 欧式距离计算公式
def biKMeans(dataSet,k,distMeas=distEclud):
m=shape(dataSet)[0]#数据集样本个数
clusterAssment=mat(zeros((m,2)))
centroid0=mean(dataSet,axis=0).tolist()[0]#centroid0是初始质心
centList=[centroid0]
for j in range(m):
clusterAssment[j,1]=distMeas(mat(centroid0),dataSet[j,:])**2
while(len(centList)<k):#当质心的个数小于k的时候
lowestSSE=inf
#依次遍历所有的簇,在所有的簇上进行kmeans(2)的划分,找到具有最小的SSE(误差平方和)
for i in range(len(centList)):
#过滤簇i的数据集到ptsIncurrCluster
ptsInCurrCluster=dataSet[nonzero(clusterAssment[:,0].A==i)[0],:]
#尝试对其进行二分
centroidMat,splitClustAss=KMeans(ptsInCurrCluster,2,distMeas)
#计算划分部分的SSE
sseSplit=sum(splitClustAss[:,1])
#计算没有划分部分的SSE
sseNotSplit=sum(clusterAssment[nonzero(clusterAssment[:,0].A!=i)[0],1])
#若采用当前的划分方式能产生更小的SSE
if (sseNotSplit+sseSplit)<lowestSSE:
bestCentToSplit=i#计算该进一步划分的簇
bestNewCents=centroidMat#新产生的2个簇的质心
bestClustAss=splitClustAss.copy()#新产生的Clust记录数据
lowestSSE=sseSplit+sseNotSplit#更新lowestSSE
#更新簇分类结果
bestClustAss[nonzero(bestClustAss[:,0].A==1)[0],0]=len(centList)#产生新标号
bestClustAss[nonzero(bestClustAss[:,0].A==0)[0],0]=bestCentToSplit
#添加新的簇的质心
centList[bestCentToSplit]=bestNewCents[0,:].tolist()[0]#更新原质心的位置
centList.append(bestNewCents[1,:].tolist()[0])#增加了新的质心
clusterAssment[nonzero(clusterAssment[:,0].A==bestCentToSplit)[0],:]=bestClustAss#更新第bestCentToSplit簇中所有的样本点
return mat(centList),clusterAssment

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
datMat=mat(loadDataSet('testSet2.txt'))
centList,myNewAssments=biKMeans(datMat,3)
print "二分KMeans:";
print centList
print " ";
cent,cluster=KMeans(datMat,3)
print cent

xcord0=[]
ycord0=[]
xcord1=[]
ycord1=[]
xcord2=[]
ycord2=[]
print shape(datMat)[0];
print shape(datMat)[1]
for i in range(shape(datMat)[0]):
# if myNewAssments[i,0]==0:
xcord1.append(datMat[:,0])
ycord1.append(datMat[:,1])
fig=plt.figure()
ax=fig.add_subplot(111)
print myNewAssments[:,0]
colors=array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
data_color=[colors[int(lbl)] for lbl in myNewAssments[:,0]]
ax.scatter(xcord1,ycord1,s=50,c=data_color)
x1=[]
y1=[]
x1.append(cent[0,0])
y1.append(cent[0,1])
#ax.scatter(x1,y1,s=100,color='red')

x2=[]
y2=[]
x2.append(cent[1,0])
y2.append(cent[1,1])
#ax.scatter(x2,y2,s=100,color='red')

x3=[]
y3=[]
x3.append(cent[2,0])
y3.append(cent[2,1])
#ax.scatter(x3,y3,s=100,color='red')

#x4=[]
#y4=[]
#x4.append(cent[3,0])
#y4.append(cent[3,1])
#ax.scatter(x4,y4,s=100,color='red')

x11=[]
y11=[]
x11.append(centList[0,0])
y11.append(centList[0,1])
ax.scatter(x11,y11,s=150,color='black')

x22=[]
y22=[]
x22.append(centList[1,0])
y22.append(centList[1,1])
ax.scatter(x22,y22,s=150,color='black')

x33=[]
y33=[]
x33.append(centList[2,0])
y33.append(centList[2,1])
ax.scatter(x33,y33,s=150,color='black')

#x44=[]
#y44=[]
#x44.append(centList[3,0])
#y44.append(centList[3,1])
#ax.scatter(x44,y44,s=100,color='green')
plt.show()

#print "cluster:";
#print cluster




内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: