您的位置:首页 > 其它

线性回归模型

2016-04-07 19:48 246 查看

1.模型

1.1线性回归概要

回归是*监督学习*中的一种。回归的目的是在给定输入变量的时候,去预测一个或多个*连续*的目标变量的值。其中*线性回归模型*是指模型是可调参数的线性方程,当然最简单的线性回归模型同时也是输入变量的线性方程![这里写图片描述](http://img.blog.csdn.net/20160407205915664)(1),其中$$\x_0=\1$$.不过通常来说我们会对输入变量做某种固定形式的预处理或者说特征抽取,如:x是原始输入变量,而$$ϕ_j\left(x\right)$$是特征,其中$$ϕ_j$$叫基方程,此时模型即为![这里写图片描述](http://img.blog.csdn.net/20160407205824648)(2),当然虽然此时这个模型是输入变量的非线性方程,但模型仍然是线性模型,因为线性回归模型的线性指的是模型是参数的线性方程。


1.2基方程选择

形如1.1中的(2)叫做线性基方程模型,对于基方程$$ϕ_j$$有很多种选择。譬如高斯基方程,logistic sigmoid方程,以及tanh方程等。


2.策略

2.1最大似然和最小平方和误差

对于形如(2)或(1)的模型,由于参数有多种可能值,那到底哪一个才是最好的模型?其实要找到最符合数据的模型,一个最简单的方法就是去找预测值最接近真实值的模型,也就是通过最小化模型的预测值y和真实的数据值t之间的误差去求解模型的参数。通常情况下我们会选取平方和误差![这里写图片描述](http://img.blog.csdn.net/20160407210040867)(3).其实如果我们假设t服从以$$y(\vecx,\vecw)$$为均值,以beta为精度(方差的逆)的正态分布的随机变量,并且假设有一组数据是独立地从这个分布产生的,然后求出其似然方程,并且对其取对数,然后最大化此log likelihood function等价于之前提到的最小化 (3)squared error function。最后我们可以得到对于模型参数的求解:(4)




其中


当然beta值也可以求到



2.2数据说明



我代码所使用的数据形如上图,每一行为一个数据,其中前两列是输入变量的特征值,即输入变量是2维的,最后一列默认是其对应的真实的t值

3.算法

from numpy import *
import matplotlib.pyplot as plt
'''1.读数据'''
def loadDataSet(fileName):
#特征数
numFeat=len(open(fileName).readline().strip().split('\t'))-1
dataMat=[];labelMat=[]#存储输入变量,真实的t值
with open(fileName) as fr:
for line in fr.readlines():
lineFeat=[]
lineList=line.strip().split('\t')
for i in range(numFeat):
lineFeat.append(float(lineList[i]))
labelMat.append(float(lineList[-1]))
dataMat.append(lineFeat)
return dataMat,labelMat

'''2.普通线性回归'''
def standRegres(xArr,yArr):
'''.normal equation for least squares problem解形如2.1中的(4)'''
xMat=mat(xArr)
yMat=mat(yArr).T
xTx=xMat.T*xMat
if linalg.det(xTx)==0.0:#判断可逆与否(行列式是否为零)
print 'This matrix is singular,cannot do inverse'
return
ws=xTx.I*(xMat.T*yMat)#模型参数解
'''衡量模型好坏的方法'''
#1.相关系数衡量相关性(传入参数为行向量)
yHat1=xMat*ws#预测值
print u'一般的线性回归后真实值与预测值的相关性为:'
print corrcoef(yHat1.T,yMat.T)
#2.可视化
fig=plt.figure(1)
ax=fig.add_subplot(111)
#散点图刻画数据点
ax.scatter(xMat[:,1].flatten().A[0],yMat[:,0].flatten().A[0])
#画直线时数据点需要有序
xCopy=xMat.copy()
xCopy.sort(0)
yHat=xCopy*ws
ax.plot(xCopy[:,1],yHat)
plt.show()
#return ws
#我的数据文件是存放在I盘下的PRML目录中的
standRegres(*loadDataSet(r'I:\PRML\ex0.txt'))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: