您的位置:首页 > 编程语言

统计学习方法第八章AdaBoost算法的例8.1代码实践

2018-01-29 18:45 309 查看
统计学习方法第八章AdaBoost算法的例8.1代码实践

#-*- coding: utf-8 -*-
from numpy import *

def loadDataSet():
    """Return the toy training set used in Example 8.1 (AdaBoost, chapter 8).

    Returns:
        (dataMat, labelMat): ``dataMat`` is a 10 x 1 matrix holding the single
        feature x = 0..9 (one row per sample); ``labelMat`` is a 1 x 10 matrix
        of the corresponding +1/-1 labels.
    """
    # asmatrix instead of the `mat` alias, which NumPy 2.0 removed.
    dataSet = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
    label = [1, 1, 1, -1, -1, -1, 1, 1, 1, -1]
    return asmatrix(dataSet).T, asmatrix(label)

def adaBoostTrain(dataSet, label, numIt=10):
    """Train an AdaBoost ensemble of decision stumps.

    Each round fits the stump with the lowest weighted error (decisionTree),
    gives it the weight alpha = 0.5 * ln((1 - err) / err), re-weights the
    samples and stops early once the ensemble classifies the whole training
    set correctly.

    Args:
        dataSet: m x n sample matrix, one row per sample.
        label: the m labels in {-1, +1}, as a 1 x m matrix (as returned by
            loadDataSet) or a flat list.
        numIt: maximum number of boosting rounds.

    Returns:
        (classifDict, totalRetResult): ``classifDict`` is the list of stump
        dicts returned by decisionTree, each with an added 'alpha' key;
        ``totalRetResult`` is the m x 1 matrix of the accumulated vote
        sum(alpha * stump prediction) on the training data -- its sign is
        the ensemble's prediction.
    """
    classifDict = []
    m = shape(dataSet)[0]
    # m x 1 column of labels; converting once also accepts a plain list.
    labelCol = asmatrix(label).T
    totalRetResult = asmatrix(zeros((m, 1)))
    # Start from uniform sample weights.
    weight = asmatrix(ones((m, 1)) / m)
    # range (not xrange) so the loop runs on Python 2 and 3.
    for _ in range(numIt):
        bestFeat, error, EstClass = decisionTree(dataSet, label, weight)
        # decisionTree may hand back a 1 x 1 matrix or a plain float.
        err = float(asmatrix(error)[0, 0])
        # A perfect stump (err == 0) would make alpha infinite and the
        # re-weighted sample weights NaN, so clamp the denominator.
        alpha = float(0.5 * log((1.0 - err) / max(err, 1e-16)))
        bestFeat['alpha'] = alpha
        classifDict.append(bestFeat)
        # w_i <- w_i * exp(-alpha * y_i * G(x_i)), then renormalise.
        weight = multiply(weight, exp(multiply(-alpha * labelCol, EstClass)))
        weight = weight / weight.sum()
        totalRetResult += alpha * EstClass
        totalError = (labelCol != sign(totalRetResult)).sum() / float(m)
        if totalError == 0:
            break
    return classifDict, totalRetResult

def splitDataSet(dataMat, feat, value, comp, m):
    """Classify every sample with a one-feature threshold stump.

    Args:
        dataMat: sample matrix; column ``feat`` is compared with ``value``.
        feat: index of the feature column to test.
        value: the threshold.
        comp: 'LT' labels samples strictly below ``value`` as -1; any other
            value labels samples strictly above ``value`` as -1.
        m: number of samples (rows) in ``dataMat``.

    Returns:
        An m x 1 float array of +1/-1 predictions.
    """
    column = dataMat[:, feat]
    negativeMask = (column < value) if comp == 'LT' else (column > value)
    prediction = ones((m, 1))
    prediction[negativeMask] = -1.0
    return prediction

def decisionTree(dataSet, labelList, weight):
    """Return the decision stump with the lowest weighted training error.

    Every feature column is tried. Candidate thresholds are the midpoints
    between consecutive values of that column in sorted order, and each
    threshold is tried with both polarities understood by splitDataSet
    ('LT': samples below the threshold -> -1; 'ST': samples above -> -1).
    Ties keep the first stump found.

    Args:
        dataSet: m x n samples, one row per sample (anything ``asmatrix``
            accepts).
        labelList: the m labels in {-1, +1}, as a 1 x m matrix or flat list.
        weight: the m sample weights (e.g. an m x 1 matrix).

    Returns:
        (bestFeat, minError, bestClass): ``bestFeat`` is a dict with keys
        'feat' (column index), 'value' (threshold, a float) and 'comp'
        (polarity) -- empty if no threshold could be tried (m < 2);
        ``minError`` is that stump's weighted error as a float (inf if no
        threshold was tried); ``bestClass`` holds its m x 1 +/-1 predictions.
    """
    dataMat = asmatrix(dataSet)
    labels = asarray(asmatrix(labelList).T).ravel()
    weights = asarray(weight, dtype=float).ravel()
    m, n = shape(dataMat)
    bestFeat = {}
    minError = inf
    bestClass = asmatrix(zeros((m, 1)))
    for i in range(n):
        # Sort this feature's column on its own. Sorting the whole matrix
        # along axis=i picks the wrong axis for i == 1 and fails for i >= 2.
        column = asarray(dataMat[:, i]).ravel()
        sortedValues = column[argsort(column)]
        for j in range(m - 1):
            # Plain scalar threshold rather than a 1 x 1 matrix.
            value = float(sortedValues[j] + sortedValues[j + 1]) / 2.0
            for comp in ['LT', 'ST']:
                retArray = splitDataSet(dataMat, i, value, comp, m)
                # Weighted error = total weight of the misclassified samples.
                errSet = (retArray.ravel() != labels).astype(float)
                weightError = float(weights.dot(errSet))
                if weightError < minError:
                    minError = weightError
                    bestFeat['feat'] = i
                    bestFeat['value'] = value
                    bestFeat['comp'] = comp
                    bestClass = retArray.copy()
    return bestFeat, minError, bestClass

# Train on the Example 8.1 data and show the stumps and the ensemble's
# sign(f(x)) predictions on the training set.
dataSet, label = loadDataSet()
classifDict, totalRetResult = adaBoostTrain(dataSet, label)
# Single pre-formatted argument: prints the same text on Python 2 and 3
# (the original Python 2 print statements are a SyntaxError on Python 3).
print("classifDict %s" % (classifDict,))
print(sign(totalRetResult))

执行结果如下:

classifDict [{'comp': 'ST', 'feat': 0, 'value': matrix([[ 2.5]]), 'alpha': 0.4236489301936017}, {'comp': 'ST', 'feat': 0, 'value': matrix([[ 8.5]]), 'alpha': 0.6496414920651304}, {'comp': 'LT', 'feat': 0, 'value': matrix([[ 5.5]]), 'alpha': 0.752038698388137}]
[[ 1.]
 [ 1.]
 [ 1.]
 [-1.]
 [-1.]
 [-1.]
 [ 1.]
 [ 1.]
 [ 1.]
 [-1.]]

made by zcl at CUMT

I know I can because I have a heart that beats


                                            
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  AI ML