您的位置:首页 > 其它

Machine Learning Foundations q15

2014-10-14 13:52 246 查看
# -*- coding: utf-8 -*-

import fileinput
import numpy as np

def install(fileName):
xSet=[]
ySet=[]
for line in fileinput.input(fileName):
num_str = line.split()
xSet.append(map(float, num_str[0:-1]))
ySet.append(int(num_str[-1]))
return (np.matrix(xSet),np.matrix(ySet).T)

xSet,ySet=install('hw1_15_train.dat') #从训练集中读取
xSet = np.concatenate((np.ones((xSet.shape[0],1)), xSet), 1)

w = np.matrix(np.zeros(5)) #初始化w,[[ 0.  0.  0.  0.  0.]]

count = 0
while True:
correct_num = 0
for i in xrange(np.shape(xSet)[0]): #从martix里提取行数
xn = xSet[i]
yn = ySet[i]
dot = np.dot(xn,w.T)
if dot*yn <=0:
w += yn*xn
count = count + 1
else:
correct_num = correct_num + 1
if correct_num == 400:
break
else:
print correct_num
print count, w
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐