您的位置:首页 > 其它

使用TensorFlow高级API实现kmeans聚类

2018-03-29 11:06 459 查看
TensorFlow可以用来解决很多机器学习问题。TensorFlow提供了tf.contrib.factorization.KMeansClustering高级API可以十分方便地实现聚类。
下面以经典的iris花数据集为例,实现一个简单的聚类demo。
首先导入数据,从sklearn的datasets中导入iris数据集。
然后调用api实现一个聚类函数。值得注意地是此处使用了tf.train.limit_epochs来作为数据读入函数。
最后实现一个聚类结果展示函数,比较聚类结果和实际类别间的差异,可以发现聚类算法可以较好的划分类别。
备注:
完整代码:
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import tensorflow as tf
import warnings
#warnings.filterwarnings("ignore")

def loadData(iris):
    X=iris.data
    y=iris.target
    return X,y
def kmeansCluster(X,numClusters):
    get_inputs=lambda: tf.train.limit_epochs(tf.convert_to_tensor(X, dtype=tf.float32), num_epochs=1)
    cluster = tf.contrib.factorization.KMeansClustering(num_clusters=numClusters,
                                                      initial_clusters=tf.contrib.factorization.KMeansClustering.KMEANS_PLUS_PLUS_INIT)
    cluster.train(input_fn=get_inputs, steps=2000)
    y_pred=cluster.predict_cluster_index(input_fn=get_inputs)
    y_pred=np.asarray(list(y_pred))
    return y_pred
def plotFigure(fignum,title, X,y):
    fig = plt.figure(fignum, figsize=(8,6))
    ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
    ax.scatter(X[:, 3], X[:, 0], X[:, 2],
               c=y.astype(np.float), edgecolor='k')
    ax.w_xaxis.set_ticklabels([])
    ax.w_yaxis.set_ticklabels([])
    ax.w_zaxis.set_ticklabels([])
    ax.set_xlabel('Petal width')
    ax.set_ylabel('Sepal length')
    ax.set_zlabel('Petal length')
    ax.set_title(title)
    ax.dist = 10
    fig.show()

if __name__ == '__main__':
    X,y = loadData(datasets.load_iris())
    y_pred = kmeansCluster(X,3)
    plotFigure(1,"3 clusters",X,y_pred)
    plotFigure(2,"Ground Truth",X,y)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: