梯度下降法实现softmax回归MATLAB程序
2016-06-29 17:10
711 查看
梯度下降法实现softmax回归MATLAB程序
版权声明:本文原创,转载须注明来源。解决二分类问题时我们通常用Logistic回归,而解决多分类问题时若果用Logistic回归,则需要设计多个分类器,这是相当麻烦的事情。softmax回归可以看做是Logistic回归的普遍推广(Logistic回归可看成softmax回归在类别数为2时的特殊情况),在多分类问题上softmax回归是一个有效的工具。
关于softmax回归算法的理论知识可参考这两篇博文:http://deeplearning.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92 ;
http://blog.csdn.net/acdreamers/article/details/44663305 。
本文自编mysoftmax_gd函数用于实现梯度下降softmax回归,代码如下(链接:http://pan.baidu.com/s/1geF2WMJ 密码:9x3x):
MATLAB程序代码:
function [theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR,varargin) % 该函数用于实现梯度下降法softmax回归 % 调用方式:[theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR,varargin) % X_test:测试输入数据 % X:训练输入数据,组织为m*p矩阵,m为案例个数,p为加上常数项之后的属性个数 % label:训练数据标签,组织为m*1向量(数值型) % lambda:权重衰减参数weight decay parameter % alpha:梯度下降学习速率 % MAX_ITR:最大迭代次数 % varargin:可选参数,输入初始迭代的theta系数,若不输入,则默认随机选取 % theta:梯度下降法的theta系数寻优结果 % test_pre:测试数据预测标签 % rata:训练数据回判正确率 % Genlovy Hoo,2016.06.29. genlovhyy@163.com %% 梯度下降寻优 Nin=length(varargin); if Nin>1 error('输入太多参数') % 若可选输入参数超过1个,则报错 end [m,p] = size(X); numClasses = length(unique(label)); % 求取标签类别数 if Nin==0 theta = 0.005*randn(p,numClasses); % 若没有输入可选参数,则随机初始化系数 else theta=varargin{1}; % 若有输入可选参数,则将其设定为初始theta系数 end cost=zeros(MAX_ITR,1); % 用于追踪代价函数的值 for k=1:MAX_ITR [cost(k),grad] = softmax_cost_grad(X,label,lambda,theta); % 计算代价函数值和梯度 theta=theta-alpha*grad; % 更新系数 end %% 回判预测 [~,~,Probit] = softmax_cost_grad(X,label,lambda,theta); [~,label_pre] = max(Probit,[],2); index = find(label==label_pre); % 找出预测正确的样本的位置 rate = length(index)/m; % 计算预测精度 %% 绘制代价函数图 figure('Name','代价函数值变化图'); plot(0:MAX_ITR-1,cost) xlabel('迭代次数'); ylabel('代价函数值') title('代价函数值变化图');% 绘制代价函数值变化图 %% 测试数据预测 [mt,pt] = size(X_test); Probit_t = zeros(mt,length(unique(label))); for smpt = 1:mt Probit_t(smpt,:) = exp(X_test(smpt,:)*theta)/sum(exp(X_test(smpt,:)*theta)); end [~,test_pre] = max(Probit_t,[],2);
function [cost,thetagrad,P] = softmax_cost_grad(X,label,lambda,theta) % 用于计算代价函数值及其梯度 % X:m*p输入矩阵,m为案例个数,p为加上常数项之后的属性个数 % label:m*1标签向量(数值型) % lambda:权重衰减参数weight decay parameter % theta:p*k系数矩阵,k为标签类别数 % cost:总代价函数值 % thetagrad:梯度矩阵 % P:m*k分类概率矩阵,P(i,j)表示第i个样本被判别为第j类的概率 m = size(X,1); % 将每个标签扩展为一个k维横向量(k为标签类别数),若样本i属于第j类,则 % label_extend(i,j)= 1,否则label_extend(i,j)= 0。 label_extend = [full(sparse(label,1:length(label),1))]'; % 计算预测概率矩阵 P = zeros(m,size(label_extend,2)); for smp = 1:m P(smp,:) = exp(X(smp,:)*theta)/sum(exp(X(smp,:)*theta)); end % 计算代价函数值 cost = -1/m*[label_extend(:)]'*log(P(:))+lambda/2*sum(theta(:).^2); % 计算梯度 thetagrad = -1/m*X'*(label_extend-P)+lambda*theta;
clear clc close all load fisheriris % MATLAB自带数据集 % 对标签重新编号并准备训练/测试数据集 index_train = [1:40,51:90,101:140]; index_test = [41:50,91:100,141:150]; species_train = species(index_train); X=[ones(length(species_train),1),meas(index_train,:)]; label = zeros(size(species_train)); label(strcmp('setosa',species_train)) = 1; label(strcmp('versicolor',species_train)) = 2; label(strcmp('virginica',species_train)) = 3; species_test = species(index_test); X_test = [ones(length(species_test),1),meas(index_test,:)]; lambda = 0.004; % 权重衰减参数Weight decay parameter alpha = 0.1; % 学习速率 MAX_ITR=500; % 最大迭代次数 [theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR)
clear clc close all load MNISTdata % MNIST数据集 % 准备训练/测试数据集 label = labels(1:9000); % 训练集标签 X = [ones(length(label),1),[inputData(:,1:9000)]']; % 训练集输入数据 label_test = labels(9001:end); % 测试集标签 X_test = [ones(length(label_test),1),[inputData(:,9001:end)]']; % 测试输入数据 lambda = 0.004; % 权重衰减参数Weight decay parameter alpha = 0.1; % 学习速率 MAX_ITR=100; % 最大迭代次数 [theta,test_pre,rate] = mysoftmax_gd(X_test,X,label,lambda,alpha,MAX_ITR) index_t = find(label_test==test_pre); % 找出预测正确的样本的位置 rate_test = length(index_t)/length(label_test); % 计算预测精度
水平有限,敬请指正交流。genlovhyy@163.com 。
参考资料:
【1】:http://deeplearning.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
【2】:http://blog.csdn.net/acdreamers/article/details/44663305
相关文章推荐
- 解析在main函数之前调用函数以及对设计的作用详解
- 详解Matlab中 sort 函数用法
- java和matlab画多边形闭合折线图示例讲解
- C#调用Matlab生成的dll方法的详细说明
- 简述Matlab中size()函数的用法
- 从java中调用matlab详细介绍
- 稀疏自动编码器 (Sparse Autoencoder)
- 详解Matlab中 sort 函数用法
- 简述Matlab中size()函数的用法
- VC++与Matlab混合编程的快速实现
- Matlab 矩阵运算
- matlab与opencv部分函数的对照
- matlab神经网络工具箱创建神经网络
- Matlab
- MATLAB 入门教程
- matlab函数_连通区域
- MATLAB中函数模式和命令模式的区别
- MATLAB 添加自定义的模块到simulink库浏览器
- Export Figures for LaTeX Writing