您的位置:首页 > 其它

使用scikit-learn进行机器学习的简介(教程1)

2018-02-09 13:51 615 查看
一、机器学习:问题设定通常,一个学习问题是通过分析一些数据样本来尝试预测未知数据的属性。如果每一个样本不仅仅是一个单独的数字,比如一个多维的实例(multivariate data),也就是说有着多个属性特征我们可以把学习问题分成如下的几个大类:(1)有监督学习
数据带有我们要预测的属性。这种问题主要有如下几种:①分类
样例属于两类或多类,我们想要从已经带有标签的数据学习以预测未带标签的数据。识别手写数字就是一个分类问题,这个问题的主要目标就是把每一个输出指派到一个有限的类别中的一类。另一种思路去思考分类问题,其实分类问题是有监督学习中的离散形式问题。每一个都有一个有限的分类。对于样例提供的多个标签,我们要做的就是把未知类别的数据划分到其中的一种。
②回归
去过预期的输出包含连续的变量,那么这样的任务叫做回归。根据三文鱼的年纪和中联预测其长度就是一个回归样例。

(2)无监督学习
训练数据包含不带有目标值的输入向量x。对于这些问题,目标就是根据数据发现样本中相似的群组——聚类。或者在输入空间中判定数据的分布——密度估计,或者把数据从高维空间转换到低维空间以用于可视化
训练集和测试集

机器学习是学习一些数据集的特征属性并将其应用于新的数据。这就是为什么在机器学习用来评估算法时一般把手中的数据分成两部分。一部分我们称之为训练集,用以学习数据的特征属性。一部分我们称之为测试集,用以检验学习到的特征属性。

二、加载一个样本数据集

scikit-learn带有一些标准数据集。比如用来分类的iris数据集、digits数据集;用来回归的boston house price 数据集。接下来,我们我们从shell开启一个Python解释器并加载iris和digits两个数据集。【译注:一些代码惯例就不写了,提示符>>>之类的学过Python的都懂】[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">$ python  
>>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datasets  
>>>iris = datasets.load_iris()  
>>>digits = datasets.load_digits()</span></span></code>  


$ python
>>>from sklearn import datasets
>>>iris = datasets.load_iris()
>>>digits = datasets.load_digits()
一个数据集是一个包含数据所有元数据的类字典对象。这个数据存储在 '.data'成员变量中,是一个$n*n$的数组,行表示样例,列表示特征。在有监督学习问题中,一个或多个响应变量(Y)存储在‘.target’成员变量中。不同数据集的更多细节可以在dedicated section中找到。例如,对于digits数据集,digits.data可以访问得到用来对数字进行分类的特征:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>>print(digits.data)    
[[  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5. ...,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
 [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
 [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">16.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
 ...,  
 [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1. ...,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">6.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
 [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]  
 [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10. ...,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]]</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  


>>>print(digits.data)
[[  0.   0.   5. ...,   0.   0.   0.]
[  0.   0.   0. ...,  10.   0.   0.]
[  0.   0.   0. ...,  16.   9.   0.]
...,
[  0.   0.   1. ...,   6.   0.   0.]
[  0.   0.   2. ...,  12.   0.   0.]
[  0.   0.  10. ...,  12.   1.   0.]]
digits.target 就是数字数据集对应的真实数字值。也就是我们的程序要学习的。[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>>digits.target  
array([<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2, ..., <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8])</span></span></span></span></span></span></code>  


>>>digits.target
array([0, 1, 2, ..., 8, 9, 8])
数据数组的形状

尽管原始数据也许有不同的形状,但实际使用的数据通常是一个二维数组(n个样例,n个特征)。对于数字数据集,每一个原始的样例是一张(8 x 8)的图片,也能被使用:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>>digits.images[<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0]  
array([[  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">13.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">13.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">15.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">15.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">15.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">11.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">9.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">11.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">7.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">14.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">12.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.],  
       [  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">6.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">13.,  <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.,   <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.]])</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  


>>>digits.images[0]
array([[  0.,   0.,   5.,  13.,   9.,   1.,   0.,   0.],
[  0.,   0.,  13.,  15.,  10.,  15.,   5.,   0.],
[  0.,   3.,  15.,   2.,   0.,  11.,   8.,   0.],
[  0.,   4.,  12.,   0.,   0.,   8.,   8.,   0.],
[  0.,   5.,   8.,   0.,   0.,   9.,   8.,   0.],
[  0.,   4.,  11.,   0.,   1.,  12.,   7.,   0.],
[  0.,   2.,  14.,   5.,  10.,  12.,   0.,   0.],
[  0.,   0.,   6.,  13.,  10.,   0.,   0.,   0.]])

三、学习和预测

对于数字数据集(digits dataset),任务是预测一张图片中的数字是什么。数字数据集提供了0-9每一个数字的可能样例,可以用它们来对位置的数字图片进行拟合分类。在scikit-learn中,用以分类的拟合(评估)函数是一个Python对象,具体有fit(X,Y)和predic(T)两种成员方法。其中一个拟合(评估)样例是sklearn.svmSVC类,它实现了支持向量分类(SVC)。一个拟合(评估)函数的构造函数需要模型的参数,但是时间问题,我们将会把这个拟合(评估)函数作为一个黑箱:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import svm  
>>>clf = svm.SVC(gamma=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, C=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">100.)</span></span></span></span></code>  


>>>from sklearn import svm
>>>clf = svm.SVC(gamma=0.001, C=100.)
选择模型参数

我们调用拟合(估测)实例clf作为我们的分类器。它现在必须要拟合模型,也就是说,他必须要学习模型。这可以通过把我们的训练集传递给fit方法。作为训练集,我们使用其中除最后一组的所有图像。我们可以通过Python的分片语法[:-1]来选取训练集,这个操作将产生一个新数组,这个数组包含digits.dataz中除最后一组数据的所有实例。[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>>clf.fit(digits.data[:-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1], digits.target[:-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1])    
SVC(C=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">100.0, cache_size=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">200, class_weight=<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">None, coef0=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.0, degree=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3,  
gamma=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, kernel=<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'rbf', max_iter=-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1, probability=<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">False,  
random_state=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">None, shrinking=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">True, tol=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, verbose=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">False)</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  


>>>clf.fit(digits.data[:-1], digits.target[:-1])
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0, degree=3,
gamma=0.001, kernel='rbf', max_iter=-1, probability=False,
random_state=None, shrinking=True, tol=0.001, verbose=False)
现在你可以预测新的数值了。我们可以让这个训练器告诉我们digits数据集我们没有作为训练数据使用的最后一张图像是什么数字。[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>>clf.predict(digits.data[-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1])  
array([<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">8])</span></span></code>  


>>>clf.predict(digits.data[-1])
array([8])
相应的图片如下图:
正如你所看到的,这是一个很有挑战的任务:这张图片的分辨率很低。你同意分类器给出的答案吗?这个分类问题的完整示例在这里识别手写数字,你可以运行并使用它。[译:看本文附录]

四、模型持久化

可以使用Python的自带模块——pickle来保存scikit中的模型:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import svm  
>>><span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datasets  
>>>clf = svm.SVC()  
>>>iris = datasets.load_iris()  
>>>X, y = iris.data, iris.target  
>>>clf.fit(X, y)    
SVC(C=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1.0, cache_size=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">200, class_weight=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">None, coef0=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.0, degree=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3, gamma=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.0,  
  kernel=<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'rbf', max_iter=-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1, probability=<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">False, random_state=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">None,  
  shrinking=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">True, tol=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001, verbose=<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">False)  
  
>>><span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import pickle  
>>>s = pickle.dumps(clf)  
>>>clf2 = pickle.loads(s)  
>>>clf2.predict(X[<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0])  
array([<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0])  
>>>y[<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0]  
<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  


>>>from sklearn import svm
>>>from sklearn import datasets
>>>clf = svm.SVC()
>>>iris = datasets.load_iris()
>>>X, y = iris.data, iris.target
>>>clf.fit(X, y)
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
kernel='rbf', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)

>>>import pickle
>>>s = pickle.dumps(clf)
>>>clf2 = pickle.loads(s)
>>>clf2.predict(X[0])
array([0])
>>>y[0]
0
对于scikit,也许使用joblib的pickle替代——(joblib.dump&joblib.load)更有趣。因为它在处理带数据时更高效。但是遗憾的是它只能把数据持久化到硬盘而不是一个字符串(译注:搬到string字符串意味着数据在内存中):[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn.externals <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import joblib  
>>>joblib.dump(clf, <span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'filename.pkl')</span></span></span></code>  


>>>from sklearn.externals import joblib
>>>joblib.dump(clf, 'filename.pkl')
往后你就可以加载这个转储的模型(也能在另一个Python进程中使用),如下:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px">>>>clf = joblib.load(<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'filename.pkl')</span></code>  


>>>clf = joblib.load('filename.pkl')
注意

joblib.dump返回一个文件名的列表,每一个numpy数组元素包含一个clf在文件系统上的名字,在用joblib.load加载的时候所有的文件需要在相同的文件夹下注意pickle有一些安全和可维护方面的问题。请参考Model persistent 获得在scikit-learn中模型持久化的细节。五、惯例约定scikit-learn的各种拟合(评估)函数遵循一些确定的规则以使得他们的用法能够被预想到(译:使得各种学习方法的用法统一起来)①类型转换除非特别指定,输入将被转换为float64[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px"><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">import numpy  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import random_projection  
rng = np.random.RandomState(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0)  
X = rng.rand(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10,<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2000)  
X = np.array(X,dtype =<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'float32')  
<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">print x.dtype  
transformer = random_projection.GaussianRandomProjection()  
X_new = transformer.fit_transform(X)  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">print X_new.dtype</span></span></span></span></span></span></span></span></span></code>  


import numpy
from sklearn import random_projection
rng = np.random.RandomState(0)
X = rng.rand(10,2000)
X = np.array(X,dtype ='float32')
print x.dtype
transformer = random_projection.GaussianRandomProjection()
X_new = transformer.fit_transform(X)
print X_new.dtype
在这个例子中,X是float32,被fit_transform(X)转换成float64,回归被转换成float64,分类目标维持不变.
[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px"><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datesets  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">from sklearn.svm <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import SVC  
iris = datasets.load_iris()  
clf =SVC()  
clf.fit(iris.data,iris.target)  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">print list(clf.predict(iris.data[:<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3]))  
clf.fit(iris.data,iris.target_names[iris.target])  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">print list(clf.predict(iris.data[:<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3]))</span></span></span></span></span></span></span></span></code>  


from sklearn import datesets
from sklearn.svm import SVC
iris = datasets.load_iris()
clf =SVC()
clf.fit(iris.data,iris.target)
print list(clf.predict(iris.data[:3]))
clf.fit(iris.data,iris.target_names[iris.target])
print list(clf.predict(iris.data[:3]))
这里第一个predict()返回一个整数数组,是因为iris.target(一个整数数组)被用于拟合。第二个predict()返回一个字符串数组,因为iris.target_names被用于拟合。②重拟合和更新参数
一个拟合(评估)函数的混合参数(超参数)能够在通过sklearn.pipeline.Pipeline.set_params方法构造之后被更新。多次调用fit()能够覆写之前fit()学习的内容:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px"><span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">import numpy <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">as np  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">from sklearn.svm <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import SVC  
rng = np.random.RandomState(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0);  
X = rng.rand(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">100,<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10)  
Y = rng.binomial(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1,<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.5,<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">100)  
X_test = rng.rand(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5,<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">10)  
clf = SVC()  
clf.set_params(kernel = <span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'linear').fit(X,Y)  
clf.predict(X_test)  
clf.set_params(kernel=<span class="hljs-string" style="margin:0px; padding:0px; line-height:1.8">'rbf').fit(X,Y)  
clf.predict(X_test) </span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  


import numpy as np
from sklearn.svm import SVC
rng = np.random.RandomState(0);
X = rng.rand(100,10)
Y = rng.binomial(1,0.5,100)
X_test = rng.rand(5,10)
clf = SVC()
clf.set_params(kernel = 'linear').fit(X,Y)
clf.predict(X_test)
clf.set_params(kernel='rbf').fit(X,Y)
clf.predict(X_test)
这里,用SVC()构造之后,开始拟合(评估)函数默认的'rbf'核被改编成'linear',后来又改回'rbf'去重拟合做第二次的预测。

附:

①digits数据集:一个展示怎样用scikit-learn识别手写数字的样例:绘制数字:[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px"><span class="hljs-comment" style="margin:0px; padding:0px; color:green; line-height:1.8"># Code source: Gaël Varoquaux  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># Modified for documentation by Jaques Grobler  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># License: BSD 3 clause  
<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datasets  
<span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import matplotlib.pyplot <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">as plt  
<span class="hljs-comment" style="margin:0px; padding:0px; color:green; line-height:1.8">#Load the digits dataset  
digits = datasets.load_digits()  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8">#Display the first digit  
plt.figure(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1, figsize=(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">3))  
plt.imshow(digits.images[-<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1], cmap=plt.cm.gray_r, interpolation=<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'nearest')  
plt.show()</span></span></span></span></span></span></span></span></span></span></span></span></span></span></code>  


# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause
from sklearn import datasets
import matplotlib.pyplot as plt
#Load the digits dataset
digits = datasets.load_digits()
#Display the first digit
plt.figure(1, figsize=(3, 3))
plt.imshow(digits.images[-1], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

②绘制数字分类 (plot_digits_classification.py)[python] view plain copyprint?<code class="hljs" style="margin:0px; padding:0px"><span class="hljs-comment" style="margin:0px; padding:0px; color:green; line-height:1.8"># Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># License: BSD 3 clause  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># Standard scientific Python imports  
<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">import matplotlib.pyplot <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">as plt  
<span class="hljs-comment" style="margin:0px; padding:0px; color:green; line-height:1.8"># Import datasets, classifiers and performance metrics  
<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">from sklearn <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">import datasets, svm, metrics  
<span class="hljs-comment" style="margin:0px; padding:0px; color:green; line-height:1.8"># The digits dataset  
digits = datasets.load_digits()  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># The data that we are interested in is made of 8x8 images of digits, let's  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># have a look at the first 3 images, stored in the `images` attribute of the  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># dataset.  If we were working from image files, we could load them using  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># pylab.imread.  Note that each image must have the same size. For these  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># images, we know which digit they represent: it is given in the 'target' of  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># the dataset.  
images_and_labels = list(zip(digits.images, digits.target))  
<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">for index, (image, label) <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">in enumerate(images_and_labels[:<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4]):  
    plt.subplot(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4, index + <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1)  
    plt.axis(<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'off')  
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation=<span class="hljs-string" style="margin:0px; padding:0px; line-height:1.8">'nearest')  
    plt.title(<span class="hljs-string" style="margin:0px; padding:0px; line-height:1.8">'Training: %i' % label)  
<span class="hljs-comment" style="margin:0px; padding:0px; color:green; line-height:1.8"># To apply a classifier on this data, we need to flatten the image, to  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># turn the data in a (samples, feature) matrix:  
n_samples = len(digits.images)  
data = digits.images.reshape((n_samples, -<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">1))  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># Create a classifier: a support vector classifier  
classifier = svm.SVC(gamma=<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">0.001)  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># We learn the digits on the first half of the digits  
classifier.fit(data[:n_samples / <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2], digits.target[:n_samples / <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2])  
<span class="hljs-comment" style="margin:0px; padding:0px; line-height:1.8"># Now predict the value of the digit on the second half:  
expected = digits.target[n_samples / <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2:]  
predicted = classifier.predict(data[n_samples / <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2:])  
print(<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">"Classification report for classifier %s:\n%s\n"  
      % (classifier, metrics.classification_report(expected, predicted)))  
print(<span class="hljs-string" style="margin:0px; padding:0px; line-height:1.8">"Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))  
images_and_predictions = list(zip(digits.images[n_samples / <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2:], predicted))  
<span class="hljs-keyword" style="margin:0px; padding:0px; color:rgb(0,0,255); line-height:1.8">for index, (image, prediction) <span class="hljs-keyword" style="margin:0px; padding:0px; line-height:1.8">in enumerate(images_and_predictions[:<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4]):  
    plt.subplot(<span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">2, <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">4, index + <span class="hljs-number" style="margin:0px; padding:0px; line-height:1.8">5)  
    plt.axis(<span class="hljs-string" style="margin:0px; padding:0px; color:rgb(163,21,21); line-height:1.8">'off')  
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation=<span class="hljs-string" style="margin:0px; padding:0px; line-height:1.8">'nearest')  
    plt.title(<span class="hljs-string" style="margin:0px; padding:0px; line-height:1.8">'Prediction: %i' % prediction)  
plt.show()</span></span></span></span></span></span></span></span></span></span></span></span></span></span></
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐