机器学习基石 作业4 带Regularizer和Cross Validation的线性回归分类器
2015-10-30 21:09
232 查看
#!/usr/bin/env python # -*- coding: utf-8 -*- """ __title__ = 'main.py' __author__ = 'w1d2s' __mtime__ = '2015/10/30' """ from numpy import * from RidgeReg import * from Validation import * import sys import string def Data_Pretreatment(path): rawData = open(path).readlines() #print rawData dataNum = len(rawData) dataDim = len(rawData[0].strip().split(' ')) - 1 dataIdx = 0 X = zeros([dataNum, dataDim]) Y = zeros(dataNum) print(dataNum, dataDim) for line in rawData: tempList = line.strip().split(' ') Y[dataIdx] = string.atoi(tempList[dataDim]) X[dataIdx, :] = tempList[0: dataDim] dataIdx += 1 return (X, Y) if __name__ == '__main__': Xtrain, Ytrain = Data_Pretreatment('train.dat') Xtest, Ytest = Data_Pretreatment('test.dat') #(Wt, p) = Validate(Xtrain, Ytrain, 120, False) (Wt, p) = Cross_Validate(Xtrain, Ytrain, 5) rate = 10 ** p W = Ridge_Regression(Xtrain, Ytrain, rate) Ein = Err_Counter(Xtrain, Ytrain, W) Eout = Err_Counter(Xtest, Ytest, W) print '** Ein : ' + str(float(Ein)/200) print '** Eout : ' + str(float(Eout)/1000)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
__title__ = 'RidgeReg.py'
__author__ = 'w1d2s'
__mtime__ = '2015/10/30'

Ridge (L2-regularized) linear regression used as a binary classifier,
plus a helper that counts misclassified samples.
"""
from numpy import *
from scipy import linalg
import random


def _augment(X):
    """Return a copy of ``X`` with a column of ones prepended (bias input)."""
    (dataSize, dataDim) = X.shape
    Z = ones([dataSize, dataDim + 1])
    Z[:, 1: dataDim + 1] = X
    return Z


def Err_Counter(X, Y, W):
    """Count the samples that the linear classifier ``W`` gets wrong.

    A sample is counted as an error when ``y * (w . z) <= 0``, where ``z``
    is the sample with a leading 1 prepended; a score of exactly zero
    therefore counts as an error.

    Args:
        X: Feature matrix of shape (samples, features).
        Y: Labels, either shape (samples,) or (samples, 1).
        W: Weights of length ``features + 1`` (bias first), either 1-D or
            a column vector.

    Returns:
        The number of misclassified samples as a Python ``int``.
    """
    scores = dot(_augment(X), W).ravel()
    # Flatten both operands so a column-shaped W or Y cannot broadcast the
    # element-wise product into an (n, n) matrix.
    labels = asarray(Y).ravel()
    return int((labels * scores <= 0).sum())


def Ridge_Regression(X, Y, rate):
    """Fit ridge regression weights ``(Z^T Z + rate * I)^-1 Z^T Y``.

    ``Z`` is ``X`` with a leading column of ones, so the bias weight is
    regularized together with the other weights.

    Args:
        X: Feature matrix of shape (samples, features).
        Y: Targets of shape (samples,) or (samples, 1).
        rate: Regularization strength (lambda).

    Returns:
        The weight vector, shaped (features + 1,) when ``Y`` is 1-D or
        (features + 1, 1) when ``Y`` is a column.

    Raises:
        numpy.linalg.LinAlgError: If the regularized matrix is singular.
    """
    Z = _augment(X)
    Zt = transpose(Z)
    A = dot(Zt, Z) + rate * identity(Z.shape[1])
    # Solve the linear system directly rather than forming the explicit
    # inverse: same result, fewer operations and better conditioning.
    return linalg.solve(A, dot(Zt, Y))
#!/usr/bin/env python # -*- coding: utf-8 -*- """ __title__ = 'Validation.py' __author__ = 'w1d2s' __mtime__ = '2015/10/30' """ from numpy import * from RidgeReg import * def Data_Spliter(X, Y, Num4Train): Xtrain = X[0: Num4Train, :] Ytrain = Y[0: Num4Train] Xval = X[Num4Train: , :] Yval = Y[Num4Train: ] return [Xtrain, Ytrain, Xval, Yval] def Validate(X, Y, Num4Train, IsEt): [Xt, Yt, Xv, Yv] = Data_Spliter(X, Y, Num4Train) minEt = 120 minEv = 80 Wt = zeros([1, Xt.ndim + 1]) p = 0 for pow in range(-10, 3): rate = 10 ** pow W = Ridge_Regression(Xt, Yt, rate) Et = Err_Counter(Xt, Yt, W) Ev = Err_Counter(Xv, Yv, W) if IsEt == True: if Et <= minEt: [Wt, minEt, p] = [W, Et, pow] print '== Et : ' + str(float(Et)/120) print '== log lambda : ' + str(pow) else: if Ev <= minEv: [Wt, minEv, p] = [W, Ev, pow] print '== Ev : ' + str(float(Ev)/80) print '== log lambda : ' + str(pow) Et = Err_Counter(Xt, Yt, Wt) Ev = Err_Counter(Xv, Yv, Wt) print 'log lambda : ' + str(p) print 'Et : ' + str(float(Et)/120) print 'Ev: ' + str(float(Ev)/80) return (Wt, p) def Data_Spliter2(X, Y, folds): dataSize = len(Y) inc = dataSize / folds Xlist = [] Ylist = [] for idx in range(0, dataSize, inc): Xtemp = X[idx: idx + inc, :] Ytemp = Y[idx: idx + inc] Xlist.append(Xtemp) Ylist.append(Ytemp) return (Xlist, Ylist) def Cross_Validate(X, Y, folds): (Xlist, Ylist) = Data_Spliter2(X, Y, folds) (foldSize, foldDim) = Xlist[0].shape Xt = zeros([foldSize * 4, foldDim]) Yt = zeros([foldSize * 4, 1]) Wt = zeros([1, foldDim + 1]) p = 0 minEcv = 10000 for pow in range(-10, 3): rate = 10 ** pow EcvSum = 0 for V in range(0, folds): beg = 0 for idx in range(0, folds): if idx == V: Xv = Xlist[idx] Yv = Ylist[idx] else: Xt[beg: beg + foldSize, :] = Xlist[idx] Ylist[idx].shape = (Ylist[idx].shape[0], 1) Yt[beg: beg + foldSize] = Ylist[idx] beg = beg + foldSize W = Ridge_Regression(Xt, Yt, rate) Ecv = Err_Counter(Xv, Yv, W) EcvSum = EcvSum + Ecv if float(EcvSum)/folds <= minEcv: minEcv = float(EcvSum)/folds (Wt, p) = 
(W, pow) print 'log lambda: ' + str(p) print 'Ecv : ' + str(minEcv) return (Wt, p)
相关文章推荐
- 插件管理Alcatraz和使用
- 轮播图---可以动态添加图片,(封装成一个函数)
- 子线程与主线程通信的其他方法概述
- Android编译中m、mm、mmm的区别
- Tomcat shutdown无法结束进程的问题
- SDUT 1489 已知中序后序二叉树的先序,深度
- 解决IE9 IE8的跨域 请求问题
- 第6章 接口和实现
- Linux rpm 命令参数使用详解[介绍和应用]
- Swift - 图像控件(UIImageView)的用法
- Redis实现分布式存储
- 关于测试方面一些版本控制及管理工具的安装及使用
- 阿里“三活”数据中心实践经验:没人能做,我们就自己做
- HDU 4609 3-idiots
- [Django后台管理系统]激活Django自带的管理界面
- LeetCode Shortest Palindrome
- 二进制中1的个数
- 零基础ios开发(七 字符串和动态数组的联合使用)
- java设计模式---桥接模式
- json中的数组操作