
How to find suitable parameters using cross-validation

2014-03-02 17:35
Problem description:

Suppose we have a model with one or more unknown parameters, and a data set to which the model can be fit (the training data set).

The problem is how to choose parameter values that make the model fit as well as possible, judged on data it was not trained on rather than on the training data alone. One answer is cross-validation. If you want to know more about what cross-validation is, see http://en.wikipedia.org/wiki/Cross-validation_(statistics). In this post I implement the method in real code, taking an SVM with an RBF kernel as the example.
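
To make concrete what is being measured, the cross-validation error for one parameter pair can be written as below. This is only a sketch in my own notation, not taken from any library: D_1, ..., D_k are the k subsamples (the code in this post uses k = crossSize = 10), and f^{(-i)}_{C,\gamma} stands for the SVM trained with parameters C and gamma on every subsample except D_i.

```latex
\mathrm{CV}(C,\gamma)
  = \frac{1}{k}\sum_{i=1}^{k}\frac{1}{|D_i|}
    \sum_{(x,y)\in D_i}\mathbf{1}\!\left[f^{(-i)}_{C,\gamma}(x)\neq y\right],
\qquad
(\hat{C},\hat{\gamma}) = \arg\min_{(C,\gamma)}\ \mathrm{CV}(C,\gamma)
```

In words: each subsample is held out once, the error rate on it is computed, and the k error rates are averaged. The pair with the smallest average is the one to prefer.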


Environment:

Python: 2.7.0
Machine learning tool: scikit-learn (sklearn)

Code:
from __future__ import division, print_function

import random

from numpy import arange
from sklearn import svm


def mainTrain(data, label):
    '''
    data is specified like this
    [[1,2,3],[2,3,4.53],...]
    label is corresponding to data
    '''
    crossSize = 10
    # randomly partition the original sample into crossSize subsamples of (nearly) equal size
    data, label = splitData(data, label, crossSize)
    C_range = 10.0 ** arange(-2, 9)      # options available for C
    gamma_range = 10.0 ** arange(-5, 4)  # options available for gamma
    for c in C_range:
        for ga in gamma_range:
            cvError = 0.0  # sum of the error rates measured on each held-out subsample
            for i in range(crossSize):
                # subsample i is the test set; all the others form the training set
                trainingSet = []
                trainingLabel = []
                testSet = data[i]
                testLabel = label[i]
                error_time = 0
                for j in range(crossSize):
                    if i != j:
                        trainingSet.extend(data[j])
                        trainingLabel.extend(label[j])

                rbf_svc = svm.SVC(kernel='rbf', C=c, gamma=ga)
                rbf_svc.fit(trainingSet, trainingLabel)

                result = rbf_svc.predict(testSet)
                # count the misclassified test samples
                for k in range(len(testLabel)):
                    if testLabel[k] != result[k]:
                        error_time += 1
                cvError += error_time / len(testLabel)
            # average error over the crossSize subsamples for this (C, gamma) pair
            print("C=%g gamma=%g error:%f" % (c, ga, cvError / crossSize))


def splitData(data, label, crossSize=10):
    # needs at least crossSize samples, otherwise some pieces end up empty
    # work on copies so that the caller's lists are not emptied
    data = list(data)
    label = list(label)
    pieceSize = len(data) // crossSize
    splitedData = []
    splitedLabel = []
    for i in range(crossSize - 1):
        dataPiece = []
        dataLabel = []
        for j in range(pieceSize):
            # pick a random remaining sample and move it into the current piece
            randIndex = random.randrange(len(data))
            dataPiece.append(data.pop(randIndex))
            dataLabel.append(label.pop(randIndex))
        splitedData.append(dataPiece)
        splitedLabel.append(dataLabel)
    # the last piece takes whatever samples are left
    splitedData.append(data)
    splitedLabel.append(label)
    return splitedData, splitedLabel
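
The code above only prints the error for every (C, gamma) pair, so the final step of picking the pair with the lowest error is left to the reader. A minimal sketch of that step follows. The names select_best_params and toy_cv_error are hypothetical helpers introduced here for illustration; toy_cv_error is an invented stand-in for a real cross-validation run and does not come from any measurement.

```python
import itertools
import math


def select_best_params(c_values, gamma_values, cv_error):
    """Return (best_c, best_gamma, best_error) over the whole grid.

    cv_error is a hypothetical callable (c, gamma) -> float that returns
    the cross-validation error for one parameter pair.
    """
    best = None
    for c, gamma in itertools.product(c_values, gamma_values):
        err = cv_error(c, gamma)
        # keep the first pair that achieves the smallest error seen so far
        if best is None or err < best[2]:
            best = (c, gamma, err)
    return best


# The same grids as in the post, written without numpy.
C_range = [10.0 ** e for e in range(-2, 9)]
gamma_range = [10.0 ** e for e in range(-5, 4)]


def toy_cv_error(c, gamma):
    # Invented error surface for illustration only: it is smallest at
    # C = 10 and gamma = 0.01 and grows with distance on a log scale.
    return abs(math.log10(c) - 1) + abs(math.log10(gamma) + 2)


best_c, best_gamma, best_err = select_best_params(C_range, gamma_range, toy_cv_error)
print("best C=%g gamma=%g error=%f" % (best_c, best_gamma, best_err))
```

In real use, cv_error would wrap the inner loop over subsamples from mainTrain and return the averaged error instead of printing it. scikit-learn also provides a GridSearchCV class that automates this kind of search, which may be preferable to a hand-written loop.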