利用Theano理解深度学习——Logistic Regression
2016-01-02 20:07
471 查看
一、Logistic Regression
1、LR模型
Logistic回归是广义线性模型的一种,属于线性的分类模型,在其模型中主要有两个参数,即:权重矩阵W和偏置向量b。在Logistic回归中,主要是将输入向量映射到一组超平面,每一个超平面代表了一个类别。输入向量到超平面的距离表示的是输入向量属于对应的类别的成员的概率。对于输入向量x,其属于类别i的概率为:
P(Y=i∣x,W,b)=softmaxi(Wx+b)=eWix+bi∑jeWjx+bj
模型对于输入向量x的预测结果为ypred是所有类别的预测中概率值最大的,即:
ypred=argmaxiP(Y=i∣x,W,b)
通常使用的是其二分类的模型,即属于类别0或者类别1。
2、损失函数
在LR模型中,需要求解的参数为权重矩阵W和偏置向量b,为了求解模型的两个参数,首先必须定义损失函数。对于上述的多类别Logistic回归,可以使用Log似然函数作为其损失函数,但是通过Log似然函数求解其参数时,必须求解Log似然函数的极大值,即使用极大似然法估计参数。这等价于在模型参数为θ的条件下,在数据集D上最大化似然函数。其中,似然函数L为:L(θ={W,b},D)=∑i=0|D|log(P(Y=y(i)∣x(i),W,b))
为了方便,通常使用负的Log似然函数,即the negative log-likelihood(NLL)作为其损失函数,此时,需要计算的是NLL的极小值。损失函数l为:
l(θ={W,b},D)=−L(θ={W,b},D)
3、随机梯度下降法
为了求解LR模型中的参数,在上面定义了LR模型的损失函数,即NLL。此时,只需计算NLL的极小值条件下的参数θ,这样的参数便是LR模型中的参数。梯度下降法是求解优化问题的较为简单的方法,其基本思想是沿着损失函数的误差表面不断计算下降的方向。对于传统的批梯度下降法,有以下的伪代码:随机梯度下降法(Stochastic gradient descent,SGD)与传统的批梯度下降法的原则一致,都是选择最快的下降方向,但是,与批梯度不同的是,在选择下降方向时,批梯度是对所有的训练样本计算其梯度,而SGD仅仅是对一部分样本计算其梯度,通常情况下,在SGD中,通常选择根据一个样本计算其梯度,SGD的伪代码如下:
在深度学习算法的模型训练中,可以使用SGD的一个变种形式,称为“minibatches”。在Minibatch SGD中,其工作原理与SGD一致,其区别仅仅是在Minibatch SGD中,通过多个样本计算其梯度,而不是根据一个样本,但又不同于批梯度下降法中的根据整个训练集计算其梯度。根据所需样本量的大小,Minibatch SGD是出于SGD与批梯度之间的一种变形形式。其伪代码如下所示:
对于minibatch的大小B,若设置太大,这将会在计算梯度的过程中浪费很多的时间,最佳的方案是依据模型,数据集以及硬件综合选择minibatch的大小。
在LR模型的计算中,此时只需计算NLL的对于参数θ的梯度,通过迭代,便能计算出模型。
二、基于Theano的Logistic Regression实现解析
1、导入数据集
导入数据集的函数为load_data(dataset),具体的函数形式如下:
def load_data(dataset): '''导入数据 :type dataset: string :param dataset: MNIST数据集 ''' #1、处理文件目录 data_dir, data_file = os.path.split(dataset)#把路径分割成dirname和basename,返回一个元组 if data_dir == "" and not os.path.isfile(dataset): new_path = os.path.join( os.path.split(__file__)[0],#__file__表示的是当前的路径 ".", "data", dataset ) if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz': dataset = new_path#文件所在的目录 if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz': import urllib origin = ( 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' ) print 'Downloading data from %s' % origin urllib.urlretrieve(origin, dataset) print '... loading data' #2、打开文件 f = gzip.open(dataset, 'rb')# 打开一个gzip已经压缩好的gzip格式的文件,并返回一个文件对象:file object. train_set, valid_set, test_set = cPickle.load(f)#载入本地的文件 f.close() '''训练集train_set,验证集valid_set和测试集test_set的格式:元组(input, target) 其中,input是一个矩阵(numpy.ndarray),每一行代表一个样本;target是一个向量(numpy.ndarray),大小与input的行数对应 ''' def shared_dataset(data_xy, borrow=True): data_x, data_y = data_xy shared_x = theano.shared(numpy.asarray(data_x, dtype=theano.config.floatX), borrow=borrow) shared_y = theano.shared(numpy.asarray(data_y, dtype=theano.config.floatX), borrow=borrow) return shared_x, T.cast(shared_y, 'int32')#将shared_y转换成整型 #3、将数据处理成需要的形式 test_set_x, test_set_y = shared_dataset(test_set) valid_set_x, valid_set_y = shared_dataset(valid_set) train_set_x, train_set_y = shared_dataset(train_set) #4、返回数据集 rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)] return rval
需要导入的模块主要有
os、
gzip和
cPickle。其中
os模块主要用于在本地查找dataset文件,具有目录的处理以及文件的判断等函数;
gzip模块提供了一些简单的对文件进行压缩和解压缩的函数功能;
cPickle模块可以对任意一种类型的python对象进行序列化操作。
1、程序中的os
模块
在load_data(dataset)函数中,使用到的主要是
os.path模块,使用到的函数是:
os.path.split(path):把路径分割成dirname和basename,返回一个元组
os.path.isfile(path):判断路径是否为文件
os.path.join(path1[, path2[, ...]]):把目录和文件名合成一个路径
注:
__file__表示的是当前的路径
2、程序中的gzip
模块
gzip模块主要提供了一些简单的对文件进行压缩和解压缩的函数功能。使用到的函数是:
gzip.open(dataset, 'rb'): 打开一个gzip已经压缩好的gzip格式的文件,并返回一个文件对象:file object.
3、程序中的cPickle
模块
cPickle模块可以对任意一种类型的python对象进行序列化操作,使用到的函数是:
cPickle.load(file):主要是载入本地的文件。
在导入数据的过程中,将数据做成了带有存储性质的形式,这样的形式可以使得变量在不同的函数之间共享,具体的构造函数为
theano.shared()。
4、theano.shared()
函数
函数theano.shared()的格式如下:
如果设置
borrow=False,这表示在使用变量的过程中将是深拷贝,对数据的任何改变不会影响到原始的变量,通过控制该参数可以实现不同函数之间对变量的共享。
2、构建LogisticRegression
类
LogisticRegression类的代码如下所示:
class LogisticRegression(object): def __init__(self, input, n_in, n_out): """ 初始化参数 :type input: theano.tensor.TensorType :param input: 一个minibatch :type n_in: int :param n_in: 输入的特征的个数 :type n_out: int :param n_out: 输出单元的个数,即输出的类别个数,在本例中共有10个类别 """ #初始化参数W和b self.W = theano.shared(value=numpy.zeros((n_in, n_out), dtype=theano.config.floatX), name='W', borrow=True) self.b = theano.shared(value=numpy.zeros((n_out,), dtype=theano.config.floatX), name='b', borrow=True) #计算属于不同的类别的概率 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) # 计算所属的类别 self.y_pred = T.argmax(self.p_y_given_x, axis=1) #参数声明 self.params = [self.W, self.b] self.input = input def negative_log_likelihood(self, y): """负的log似然函数 :type y: theano.tensor.TensorType :param y: 对应的类别标签 """ return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) '''T.arange(y.shape[0])返回的是一个向量[0,1,...,len(y)],y也是一个向量,如[3,5,6...,9],代表的是所属的类别 T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]表示的是T.log(self.p_y_given_x)[0,3],mean函数内部是一个向量 ''' def errors(self, y): """计算在minibatch中的错误率 :type y: theano.tensor.TensorType :param y: 对应的类别标签 """ # 检查y与y_pred是否具有相同的维度 if y.ndim != self.y_pred.ndim: raise TypeError( 'y should have the same shape as self.y_pred', ('y', y.type, 'y_pred', self.y_pred.type) ) # 检查y的数据格式 if y.dtype.startswith('int'): #返回错误率 return T.mean(T.neq(self.y_pred, y)) else: raise NotImplementedError()
在
LogisticRegression类中主要有三个函数,构造函数
__init__(),负的log似然函数
negative_log_likelihood()和计算错误率函数
errors()。
1、构造函数__init__()
在构造函数中主要有这样一些函数,theano.shared()、
theano.tensor.nnet.softmax(),
theano.tensor.nnet.dot()和
theano.tensor.argmax()。其中,
theano.shared()在上面已经简单解释了;
theano.tensor.nnet.softmax()主要用于计算属于每一个类别的概率;
theano.tensor.nnet.dot()用于计算矩阵计算;
theano.tensor.argmax()用于返回最终所属的类别。
2、负的log似然函数negative_log_likelihood()
在负的log似然函数negative_log_likelihood()中使用到的函数是
theano.tensor.mean(),该函数用于计算均值。
3、计算错误率函数errors()
计算错误率函数用于在validation阶段和testing阶段对模型的评估,主要的思想是利用模型对验证集以及测试集进行预测,用预测的结果y_pred与样本标签
y进行对比,记录错误的个数,并返回错误的概率。用到的函数为
theano.tensor.neq()和
theano.tensor.mean(),函数
theano.tensor.neq(self.y_pred, y)用于统计
self.y_pred和
y中不相等的样本的个数。
3、sgd_optimization_mnist
函数
这个函数是整个Logistic回归算法的核心部分,用于构建整个算法的流程,该函数主要分为以下几个部分:导入数据集
建立模型
训练模型
1、导入数据集
导入函数部分的代码已在上面解释过了。处理数据集部分的代码如下:#1、导入数据集 datasets = load_data(dataset) train_set_x, train_set_y = datasets[0]#训练集 valid_set_x, valid_set_y = datasets[1]#验证集 #计算minibatches,得到训练集,验证集和测试集的minibatch的大小 n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size
2、建立模型
在建立模型阶段,首先是一些全局符号变量的声明,然后初始化分类器,接着是构建好验证模型和训练模型,具有的代码如下:# 2、建模 print '... building the model' index = T.lscalar() # 声明一个符号变量,用于minibatch索引 # 声明符号变量x和y x = T.matrix('x') y = T.ivector('y') # 2.1 初始化分类器 classifier = LogisticRegression(input=x, n_in=28 * 28, n_out=10) cost = classifier.negative_log_likelihood(y) #验证模型 validate_model = theano.function( inputs=[index], outputs=classifier.errors(y), givens={ x: valid_set_x[index * batch_size: (index + 1) * batch_size], y: valid_set_y[index * batch_size: (index + 1) * batch_size] } ) # 计算梯度 g_W = T.grad(cost=cost, wrt=classifier.W) g_b = T.grad(cost=cost, wrt=classifier.b) # 对参数的更新 updates = [(classifier.W, classifier.W - learning_rate * g_W), (classifier.b, classifier.b - learning_rate * g_b)] # 模型的训练规则 train_model = theano.function( inputs=[index], outputs=cost, updates=updates, givens={ x: train_set_x[index * batch_size: (index + 1) * batch_size], y: train_set_y[index * batch_size: (index + 1) * batch_size] } )
3、训练模型
在模型的训练过程中,通过随机梯度下降不断调整模型中的参数W和b。在循环更新参数的过程中,停止的条件可以设置成循环的代数,也可以设置Early-stopping策略。Early-stopping策略可以有效地避免过拟合,主要是通过在验证集上观察模型的性能。#3、训练模型 print '... training the model' # early-stopping 参数 patience = 5000 patience_increase = 2 improvement_threshold = 0.995 validation_frequency = min(n_train_batches, patience / 2) best_validation_loss = numpy.inf#初始化一个最大的误差 start_time = timeit.default_timer()#计算开始的时间 done_looping = False#用于判断是否结束循环的标志 epoch = 0#当前的迭代次数 while (epoch < n_epochs) and (not done_looping): epoch = epoch + 1 #对每一个minibatch的数据进行训练 for minibatch_index in xrange(n_train_batches): minibatch_avg_cost = train_model(minibatch_index)#得到负的log似然值 #计算迭代的次数 iter = (epoch - 1) * n_train_batches + minibatch_index #每次将minibatch数据集计算一遍便开始计算validation if (iter + 1) % validation_frequency == 0: # 在验证集上验证模型的优劣 validation_losses = [validate_model(i) for i in xrange(n_valid_batches)] this_validation_loss = numpy.mean(validation_losses) print( 'epoch %i, minibatch %i/%i, validation error %f %%' % ( epoch, minibatch_index + 1, n_train_batches, this_validation_loss * 100. ) ) # 记录下在验证集上表现最好的模型 if this_validation_loss < best_validation_loss: #当性能足够好时不断提高patience以使得模型提早结束 if this_validation_loss < best_validation_loss * improvement_threshold: patience = max(patience, iter * patience_increase) best_validation_loss = this_validation_loss # 保存最优的模型 with open('best_model.pkl', 'w') as f: cPickle.dump(classifier, f) #提早退出循环 if patience <= iter: done_looping = True break end_time = timeit.default_timer()# 运行结束的时间 # 打印最优的验证结果 print( ( 'Optimization complete with best validation score of %f %%' ) % (best_validation_loss * 100.) ) # 打印运行的时间 print 'The code run for %d epochs, with %f epochs/sec' % ( epoch, 1. * epoch / (end_time - start_time)) print >> sys.stderr, ('The code for file ' + os.path.split(__file__)[1] + ' ran for %.1fs' % ((end_time - start_time)))
4、predict
函数
在predict函数中,使用到的是模型和测试数据集,具体的函数如下:
def predict(): """用训练好的模型进行预测 """ # 导入训练好的模型 classifier = cPickle.load(open('best_model.pkl')) # 建立预测模型 predict_model = theano.function( inputs=[classifier.input], outputs=classifier.y_pred) # 导入测试数据集 dataset='mnist.pkl.gz' datasets = load_data(dataset) test_set_x, test_set_y = datasets[2] test_set_x = test_set_x.get_value() predicted_values = predict_model(test_set_x[:10])#进行预测 print ("Predicted values for the first 10 examples in test set:") print predicted_values
使用的函数主要是导入函数和模型的函数,在上述都已经简单介绍过。
三、实验结果
1、训练模型
... loading data ... building the model ... training the model epoch 1, minibatch 83/83, validation error 12.458333 % epoch 2, minibatch 83/83, validation error 11.010417 % epoch 3, minibatch 83/83, validation error 10.312500 % epoch 4, minibatch 83/83, validation error 9.875000 % epoch 5, minibatch 83/83, validation error 9.562500 % epoch 6, minibatch 83/83, validation error 9.322917 % epoch 7, minibatch 83/83, validation error 9.187500 % epoch 8, minibatch 83/83, validation error 8.989583 % epoch 9, minibatch 83/83, validation error 8.937500 % epoch 10, minibatch 83/83, validation error 8.750000 % epoch 11, minibatch 83/83, validation error 8.666667 % epoch 12, minibatch 83/83, validation error 8.583333 % epoch 13, minibatch 83/83, validation error 8.489583 % epoch 14, minibatch 83/83, validation error 8.427083 % epoch 15, minibatch 83/83, validation error 8.354167 % epoch 16, minibatch 83/83, validation error 8.302083 % epoch 17, minibatch 83/83, validation error 8.250000 % epoch 18, minibatch 83/83, validation error 8.229167 % epoch 19, minibatch 83/83, validation error 8.260417 % epoch 20, minibatch 83/83, validation error 8.260417 % epoch 21, minibatch 83/83, validation error 8.208333 % epoch 22, minibatch 83/83, validation error 8.187500 % epoch 23, minibatch 83/83, validation error 8.156250 % epoch 24, minibatch 83/83, validation error 8.114583 % epoch 25, minibatch 83/83, validation error 8.093750 % epoch 26, minibatch 83/83, validation error 8.104167 % epoch 27, minibatch 83/83, validation error 8.104167 % epoch 28, minibatch 83/83, validation error 8.052083 % epoch 29, minibatch 83/83, validation error 8.052083 % epoch 30, minibatch 83/83, validation error 8.031250 % epoch 31, minibatch 83/83, validation error 8.010417 % epoch 32, minibatch 83/83, validation error 7.979167 % epoch 33, minibatch 83/83, validation error 7.947917 % epoch 34, minibatch 83/83, validation error 7.875000 % epoch 35, minibatch 83/83, validation error 7.885417 % epoch 36, minibatch 83/83, validation error 7.843750 % epoch 37, minibatch 83/83, validation error 7.802083 % epoch 38, minibatch 83/83, validation error 7.812500 % epoch 39, minibatch 83/83, validation error 7.812500 % epoch 40, minibatch 83/83, validation error 7.822917 % epoch 41, minibatch 83/83, validation error 7.791667 % epoch 42, minibatch 83/83, validation error 7.770833 % epoch 43, minibatch 83/83, validation error 7.750000 % epoch 44, minibatch 83/83, validation error 7.739583 % epoch 45, minibatch 83/83, validation error 7.739583 % epoch 46, minibatch 83/83, validation error 7.739583 % epoch 47, minibatch 83/83, validation error 7.739583 % epoch 48, minibatch 83/83, validation error 7.708333 % epoch 49, minibatch 83/83, validation error 7.677083 % epoch 50, minibatch 83/83, validation error 7.677083 % epoch 51, minibatch 83/83, validation error 7.677083 % epoch 52, minibatch 83/83, validation error 7.656250 % epoch 53, minibatch 83/83, validation error 7.656250 % epoch 54, minibatch 83/83, validation error 7.635417 % epoch 55, minibatch 83/83, validation error 7.635417 % epoch 56, minibatch 83/83, validation error 7.635417 % epoch 57, minibatch 83/83, validation error 7.604167 % epoch 58, minibatch 83/83, validation error 7.583333 % epoch 59, minibatch 83/83, validation error 7.572917 % epoch 60, minibatch 83/83, validation error 7.572917 % epoch 61, minibatch 83/83, validation error 7.583333 % epoch 62, minibatch 83/83, validation error 7.572917 % epoch 63, minibatch 83/83, validation error 7.562500 % epoch 64, minibatch 83/83, validation error 7.572917 % epoch 65, minibatch 83/83, validation error 7.562500 % epoch 66, minibatch 83/83, validation error 7.552083 % epoch 67, minibatch 83/83, validation error 7.552083 % epoch 68, minibatch 83/83, validation error 7.531250 % epoch 69, minibatch 83/83, validation error 7.531250 % epoch 70, minibatch 83/83, validation error 7.510417 % epoch 71, minibatch 83/83, validation error 7.520833 % epoch 72, minibatch 83/83, validation error 7.510417 % epoch 73, minibatch 83/83, validation error 7.500000 % Optimization complete with best validation score of 7.500000 % The code run for 74 epochs, with 2.780229 epochs/sec The code for file logistic_sgd.py ran for 26.6s
2、测试结果
... loading data Predicted values for the first 10 examples in test set: [7 2 1 0 4 1 4 9 6 9]
参考文献
Deep Learning Tutorials(http://www.deeplearning.net/tutorial/)
转载自:http://blog.csdn.net/google19890102/article/details/48976021
相关文章推荐
- iOS完全自学手册——[二]Hello World工程
- 设置搜索关键词的代码
- 嵌入式考试Shell编程题
- UIView动画
- 【成长訪谈】老翟:程序猿的企业家梦想
- longjmp()/setjmp()跳转
- PXC中文文档--第一章
- 数组指针与指针数组
- 1073. Scientific Notation (20)【字符串操作】——PAT (Advanced Level) Practise
- Xcode中常用的快捷键(原文链接http://www.cocoachina.com/ios/20141224/10752.html)
- 关于调用startActivityForResult()方法后Activity直接退出的问题原因和解决办法
- 核心动画 Core Animation
- wdcp 远程连接mysql 出现1130错误 解决办法
- PHP-MySQL扩展
- Python3 字典使用上的一个小细节
- C#窗体连接数据库出现未处理SqlException解决办法
- Java IO流详解
- 没R.java 这个文件 ,或R报错
- ubuntu下默认的头文件搜索路径
- 程序员称为高手的10条心得(摘自http://www.jizhuomi.com/software/394.html)