您的位置:首页 > 其它

源码解读----之_kmeans_single_lloyd和_kmeans_single_elkan初始化质心的方法

2017-10-24 15:25 766 查看
源码解读:_kmeans_single_lloyd 和 _kmeans_single_elkan 初始化质心的方法。_init_centroids 方法被
_kmeans_single_lloyd 和 _kmeans_single_elkan 调用,而 _kmeans_single_lloyd 和 _kmeans_single_elkan 又被 k_means 方法调用。


def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
                    init_size=None):
    """Compute the initial centroids for k-means.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Sample data.
    k : int
        Number of centroids to produce.
    init : {'k-means++', 'random'}, ndarray or callable
        Initialization method. An object exposing ``__array__`` is used
        directly as the centers; a callable is invoked as
        ``init(X, k, random_state=random_state)``.
    random_state : int, RandomState instance or None, optional
        Seed or generator; normalized through ``check_random_state``.
    x_squared_norms : array, shape (n_samples,), optional
        Squared Euclidean norm of each row of ``X``. Computed with
        ``row_norms(X, squared=True)`` when not supplied.
    init_size : int, optional
        Number of rows to draw at random from ``X`` before initializing
        (per the original note, only the mini-batch variant passes it).
        It only takes effect when smaller than ``n_samples``. It should
        be at least ``k``: a smaller value triggers a ``RuntimeWarning``
        and is replaced by ``3 * k``.

    Returns
    -------
    centers : array, shape (k, n_features)
        The initial centroids.

    Raises
    ------
    ValueError
        If no sub-sampling takes place and ``n_samples < k``, or if
        ``init`` is none of the supported kinds.
        ``_validate_center_shape`` presumably also raises on a shape
        mismatch -- its body is not visible here, confirm.
    """
    random_state = check_random_state(random_state)
    n_samples = X.shape[0]

    if x_squared_norms is None:
        # Not precomputed by the caller, so compute the row norms here.
        x_squared_norms = row_norms(X, squared=True)

    if init_size is not None and init_size < n_samples:
        # Initialize on a random subset of the rows instead of all of X.
        if init_size < k:
            warnings.warn(
                "init_size=%d should be larger than k=%d. "
                "Setting it to 3*k" % (init_size, k),
                RuntimeWarning, stacklevel=2)
            init_size = 3 * k
        # Draws `init_size` row indices from [0, n_samples), with
        # replacement, so the same row can be selected more than once.
        init_indices = random_state.randint(0, n_samples, init_size)
        X = X[init_indices]
        x_squared_norms = x_squared_norms[init_indices]
        # From here on n_samples refers to the size of the subset.
        n_samples = X.shape[0]
    elif n_samples < k:
        raise ValueError(
            "n_samples=%d should be larger than k=%d" % (n_samples, k))

    if isinstance(init, string_types) and init == 'k-means++':
        # Delegate the seeding to the k-means++ helper.
        centers = _k_init(X, k, random_state=random_state,
                          x_squared_norms=x_squared_norms)
    elif isinstance(init, string_types) and init == 'random':
        # Use k distinct rows chosen uniformly at random as the centers.
        seeds = random_state.permutation(n_samples)[:k]
        centers = X[seeds]
    elif hasattr(init, '__array__'):
        # Explicit centers were given: copy them with X's dtype (the
        # original comment notes this is needed for Cython fused types).
        centers = np.array(init, dtype=X.dtype)
    elif callable(init):
        centers = init(X, k, random_state=random_state)
        centers = np.asarray(centers, dtype=X.dtype)
    else:
        raise ValueError("the init parameter for the k-means should "
                         "be 'k-means++' or 'random' or an ndarray, "
                         "'%s' (type '%s') was passed." % (init, type(init)))

    if sp.issparse(centers):
        # Rows taken from a sparse X are sparse; return a dense array.
        centers = centers.toarray()

    # Per the original comment: checks that there are k centers and that
    # their dimensionality matches X.
    _validate_center_shape(X, k, centers)
    return centers

本文是个人的理解。由于刚接触这部分源码,自身能力也有限,文中也许存在误解,欢迎留言指正,本人一定虚心接受,谢谢。

参考地址:http://scikit-learn.org/stable/modules/classes.html
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息