您的位置:首页 > 其它

关于cs231n作业1中KNN算法实现的感悟

2018-02-26 16:31 295 查看
1、避免使用循环,直接使用数组进行计算(一般需要广播)可以很好的提高计算效率,一般情况下都是可以在原来的维度下找到相应的表达式来代替循环,切记最好不要轻易提高维度来解决问题,因为这会大大增加计算量。
>>> import numpy as np
>>> train = np.arange(30).reshape(6,5)
>>> test = np.arange(15).reshape(3,5)
 #  训练集共6个目标
# 测试集共3个目标
# 需实现L2距离的计算,无循环
# 方法一:最开始想的是通过train变为(1,6,5),test变为(3,1,5),直接用(train-test)**2后沿着2轴求和后开方即可,后发现由于计算时间较长,且内存消耗巨大,甚至出现memoryerro
>>> train = train[np.newaxis,:,:]
>>> test = test[:,np.newaxis,:]
>>> a = np.sum((train-test)**2,axis = 2)
>>> a**0.5
array([[ 0.        , 11.18033989, 22.36067977, 33.54101966, 44.72135955,
        55.90169944],
       [11.18033989,  0.        , 11.18033989, 22.36067977, 33.54101966,
        44.72135955],
       [22.36067977, 11.18033989,  0.        , 11.18033989, 22.36067977,
        33.54101966]])
 #这种方法体现的在二维里无法很好的解决问题时,可扩展至三维进行解决,但是增大了计算量,对内存要求巨大。
#方法二:(最佳方法)

>>> train = np.arange(30).reshape(6,5)
>>> test = np.arange(15).reshape(3,5)
>>> b = np.dot(test,train.T)
>>> c = np.sum(train**2,axis=1)
>>> e =np.sum(test**2,axis = 1).reshape(3,1)
>>> b = -2*b
>>> (b+c+e)**0.5
array([[ 0.        , 11.18033989, 22.36067977, 33.54101966, 44.72135955,
        55.90169944],
       [11.18033989,  0.        , 11.18033989, 22.36067977, 33.54101966,
        44.72135955],
       [22.36067977, 11.18033989,  0.        , 11.18033989, 22.36067977,
        33.54101966]])

可以发现:两种方法算出来的L2距离是相同的,但是效率相差很多,后者速度很快。
2、在得到K个最接近的类别后,需要算出哪个类别出现的次数最多,这里用到了一个对正整数一维数组适用的计算出现次数最多的元素的函数np.argmax(np.bincount(a)),非常好用。
3、mp.array_split(a)可以用来切割数组,返回数组列表,交叉验证就是指从被分割的4部分中轮流抽取一部分充当验证集,其余三部分继续充当训练集
>>> train
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]])
>>> np.array_split(train,3)
[array([[[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]]), array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]]]), array([[[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]])]

4、np.array(a)可以用来组合列表,a可以为一维数组的列表,一般常用列表解析表达式表示。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: