How to find suitable parameters using cross-validation
2014-03-02 17:35
Problem description:
Suppose we have a model with one or more unknown parameters, and a data set to which the model can be fit (the training data set). The problem is how to choose parameter values that make the model fit the data as well as possible. Cross-validation answers this. It judges each candidate setting by how well the model predicts data held out from training, not by how well it fits the data it was trained on, which would favor overfitting. To learn more about cross-validation, see http://en.wikipedia.org/wiki/Cross-validation_(statistics). In this post I implement the method in real code, using an SVM with an RBF kernel as the example.
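The heart of k-fold cross-validation is the partition itself. The samples are shuffled and dealt into k folds, and each fold serves once as the validation set while the other k-1 folds are used for training. The sketch below shows only that partition step in plain Python. The function name `k_fold_indices` and its `seed` parameter are hypothetical, chosen for illustration; this is not a scikit-learn API.

```python
import random


def k_fold_indices(n_samples, k, seed=0):
    """Return a list of (training_indices, validation_indices) pairs, one per fold.

    Hypothetical helper for illustration: shuffles 0..n_samples-1 once and deals
    the indices round-robin into k folds whose sizes differ by at most one.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # fixed seed makes the split reproducible
    folds = [indices[i::k] for i in range(k)]  # fold i gets every k-th shuffled index
    splits = []
    for i in range(k):
        validation = sorted(folds[i])
        # every index that is not in fold i goes into the training part
        training = sorted(idx for j, fold in enumerate(folds) if j != i for idx in fold)
        splits.append((training, validation))
    return splits
```

For example, `k_fold_indices(23, 10)` returns ten splits whose validation parts hold two or three indices each. Every index appears in exactly one validation part. Dealing round-robin keeps the folds balanced. The code later in this post instead gives all leftover samples to the last fold.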
Environment:
python: 2.7.0
machine-learning tool: scikit-learn (sklearn)
Code:
import random

from numpy import arange
from sklearn import svm


def mainTrain(data, label, crossSize=10):
    '''
    data is specified like this [[1,2,3],[2,3,4.53],...]
    label is corresponding to data (one label per row)
    Returns (lowest cross-validation error, C, gamma).
    '''
    # randomly partition the original sample into crossSize subsamples
    data, label = splitData(data, label, crossSize)
    C_range = 10.0 ** arange(-2, 9)      # options available for C
    gamma_range = 10.0 ** arange(-5, 4)  # options available for gamma
    best = None
    for c in C_range:
        for ga in gamma_range:
            totalError = 0.0
            for i in range(crossSize):
                # subsample i is held out for validation, the rest are used for training
                testSet = data[i]
                testLabel = label[i]
                trainingSet = []
                trainingLabel = []
                for j in range(crossSize):
                    if i != j:
                        trainingSet.extend(data[j])
                        trainingLabel.extend(label[j])
                rbf_svc = svm.SVC(kernel='rbf', C=c, gamma=ga)
                rbf_svc.fit(trainingSet, trainingLabel)
                result = rbf_svc.predict(testSet)
                # count misclassified validation samples (k, not i, to avoid shadowing)
                errorTime = 0.0
                for k in range(len(testLabel)):
                    if testLabel[k] != result[k]:
                        errorTime += 1.0
                totalError += errorTime / len(testLabel)
            cvError = totalError / crossSize  # mean validation error over all folds
            print("C=%g gamma=%g error:%f" % (c, ga, cvError))
            if best is None or cvError < best[0]:
                best = (cvError, c, ga)
    return best


def splitData(data, label, crossSize=10):
    # work on copies so the caller's lists are not emptied
    data = list(data)
    label = list(label)
    pieceSize = len(data) // crossSize  # needs at least crossSize samples
    splitedData = []
    splitedLabel = []
    for i in range(crossSize - 1):
        dataPiece = []
        dataLabel = []
        for j in range(pieceSize):
            # move a randomly chosen remaining sample into this subsample
            randIndex = random.randrange(len(data))
            dataPiece.append(data.pop(randIndex))
            dataLabel.append(label.pop(randIndex))
        splitedData.append(dataPiece)
        splitedLabel.append(dataLabel)
    # the last subsample takes all remaining samples
    splitedData.append(data)
    splitedLabel.append(label)
    return splitedData, splitedLabel
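scikit-learn can also run this search itself. The sketch below assumes scikit-learn 0.18 or later, where `GridSearchCV` lives in `sklearn.model_selection`; releases from 2014 exposed it as `sklearn.grid_search` instead. It searches the same C and gamma grids with 10-fold cross-validation. The helper name `search_rbf_params` is hypothetical, and the bundled iris data set is only placeholder input. Two differences from the hand-written loop are worth noting. With an integer `cv` and a classifier, scikit-learn uses stratified folds rather than purely random ones. The reported score is accuracy, so the error rate printed by the loop corresponds to one minus that score.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV


def search_rbf_params(data, label, cv=10):
    """Grid-search C and gamma of an RBF-kernel SVC by k-fold cross-validation.

    Hypothetical helper; returns (best parameter dict, cross-validation error).
    """
    param_grid = {
        "C": 10.0 ** np.arange(-2, 9),      # same candidates as C_range above
        "gamma": 10.0 ** np.arange(-5, 4),  # same candidates as gamma_range above
    }
    search = GridSearchCV(svm.SVC(kernel="rbf"), param_grid, cv=cv)
    search.fit(data, label)
    # best_score_ is the mean validation accuracy; the error rate is 1 - accuracy
    return search.best_params_, 1.0 - search.best_score_


if __name__ == "__main__":
    iris = load_iris()  # placeholder data, not from the original post
    params, error = search_rbf_params(iris.data, iris.target)
    print("best parameters: %s, cross-validation error: %f" % (params, error))
```

After fitting, `GridSearchCV` also refits the best setting on all the data by default, so the fitted `search` object can be used directly for prediction.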