交叉验证
2015-10-13 22:20
316 查看
1. 验证集
在使用一个机器学习模型时,通常有一些参数需要设置,比如:KNN中的kk,距离函数;
SVM算法中的(C, gamma);
GBDT中的迭代次数,树的深度;
这些参数称为超参数(hyperparameters),好的参数可以极大提高算法的预测性能。选择合适的模型参数过程称为模型选择(model selection)。那么如何选择这些参数呢?在模型学习过程中,通常做法是将数据分为训练集和测试集,其中训练集用来训练模型,测试集用来预测模型在未知数据上的预测性能。需要注意的是,绝对不能用测试集来调整这些超参数。
用测试集来调整参数的危害之一是,模型可能在测试集上取得较好地预测性能,然而当我们实际部署模型时,却发现性能很差。实际上,模型对测试集产生了过拟合。换种思路看这个问题的话,若我们使用测试集调整参数,实际上我们已经将测试集当做训练集来使用,这样模型在看的见的数据上取得不错的性能,当部署模型到实际应用时,模型对于没见过的数据预测性能很差,也就是说模型泛化能力很弱。
正确的做法是在整个过程中,测试集只能被使用一次,而且是在最后一步。那么怎么样调整这些参数呢,可以将训练集分为两部分,其中数据多的部分用来训练模型,数据少的部分,用来调整参数,这部分也称为验证集。
2. 交叉验证
当我们用来训练模型的数据规模(包括训练集和验证集)不大时,将其中部分数据划分为验证集用来调整参数有些浪费,增加了模型过拟合的可能性。这时可以采用K重交叉验证的办法。K重交叉验证相比把数据集分为(test, validation, train sets)的做法可以充分利用所有的数据,另一方面,也可以避免过拟合。做法为:将训练集分为K份,选择其中K-1份作为train set,另一份作为validation set,训练K次,同时也测试K次,将K次的平均作为该参数下的validation结果。然后对于不同的参数,重复这样的训练,选择准确率最高的参数作为最后的参数。需要注意的是,在训练过程中不接触test set。5重交叉验证的一个例子如下图所示:
在python中使用K重交叉验证的一个示例代码为:
def calc_params(X, y, clf, param_values, param_name, K): # initialize training and testing scores with zeros train_scores = np.zeros(len(param_values)) test_scores = np.zeros(len(param_values)) # iterate over the different parameter values for i, param_value in enumerate(param_values): print param_name, ' = ', param_value # set classifier parameters clf.set_params(**{param_name:param_value}) # initialize the K scores obtained for each fold k_train_scores = np.zeros(K) k_test_scores = np.zeros(K) # create KFold cross validation cv = KFold(n_samples, K, shuffle=True, random_state=0) # iterate over the K folds for j, (train, test) in enumerate(cv): # fit the classifier in the corresponding fold # and obtain the corresponding accuracy scores on train and test sets clf.fit([X[k] for k in train], y[train]) k_train_scores[j] = clf.score([X[k] for k in train], y[train]) k_test_scores[j] = clf.score([X[k] for k in test], y[test]) # store the mean of the K fold scores train_scores[i] = np.mean(k_train_scores) test_scores[i] = np.mean(k_test_scores) # plot the training and testing scores in a log scale plt.semilogx(param_values, train_scores, alpha=0.4, lw=2, c='b') plt.semilogx(param_values, test_scores, alpha=0.4, lw=2, c='g') plt.xlabel(param_name + " values") plt.ylabel("Mean cross validation accuracy") # return the training and testing scores on each parameter value return train_scores, test_scores
采用3重交叉验证的函数调用代码:
alphas = np.logspace(-7, 0, 8) train_scores, test_scores = calc_params(X, y, clf, alphas, 'nb__alpha', 3) print 'training scores: ', train_scores print 'testing scores: ', test_scores
运行结果:
nb__alpha = 1e-07 nb__alpha = 1e-06 nb__alpha = 1e-05 nb__alpha = 0.0001 nb__alpha = 0.001 nb__alpha = 0.01 nb__alpha = 0.1 nb__alpha = 1.0 training scores: [ 1. 1. 1. 1. 1. 1. 0.99683333 0.97416667] testing scores: [ 0.7713 0.7766 0.7823 0.7943 0.8033 0.814 0.8073 0.7453]
可以看出,当参数为0.01时,取得最好的测试结果。
3. 交叉验证选择特征
交叉验证除了选择模型参数外,还可以用于特征的选择。对于分类或者回归算法来说,并不是说特征的数量越多越好,一般需要对提取到的特征进行选择(feature selection)。首先,需要对特征的预测能力进行排序,然后通过交叉验证,选择最优比例的特征组合,来作为最终使用的特征。下面给出一个决策树算法通过交叉验证选择最优特征比例的python代码:
from sklearn import cross_validation percentiles = range(1, 100, 5) results = [] for i in range(1, 100, 5): fs = feature_selection.SelectPercentile(feature_selection.chi2, percentile=i) X_train_fs = fs.fit_transform(X_train, y_train) scores = cross_validation.cross_val_score(dt, X_train_fs, y_train, cv=5) #print i,scores.mean() results = np.append(results, scores.mean()) optimal_percentil = np.where(results == results.max())[0] print "Optimal number of features:{0}".format(percentiles[optimal_percentil]), "\n"
运行结果为:
Optimal number of features:6 Mean scores: [ 0.83332303 0.87804576 0.87195424 0.86994434 0.87399505 0.86891363 0.86992373 0.86991342 0.87195424 0.86991342 0.87194393 0.87398475 0.86991342 0.87093383 0.86992373 0.86074005 0.86583179 0.86790353 0.86891363 0.8648423 ]
可以看出,最优的特征比例是6%,剩余的大部分特征是冗余的。
相关文章推荐
- iOS 键盘回收实现步骤
- Android官方API Guide学习之二 设备兼容性
- C语言数据结构-树
- Android 五大布局之(一) 线性布局和相对布局
- 编译错误 error C2451: “std::_Unforced”类型的条件表达式是非法的
- 转!!数据库 第一范式(1NF) 第二范式(2NF) 第三范式(3NF)的 联系和区别
- 工厂模式、控制反转及依赖注入
- 第 二 十 九 天:监 控 软 件 之 zabbix
- hadoop 数据挖掘
- 函数
- 使用trash-cli避免误删文件--为rm增加回收站功能
- Android 6.0中art虚拟机编译dex时已完全放弃使用LLVM
- layoutSubviews总结
- 整理一些工具
- 宏定义抽取单例
- hadoop运行到mapreduce.job: Running job后停止运行
- Visual Studio 命令别名
- 高性能IO模型浅析
- Java Swing intro
- Ajax教程