# Page breadcrumb (scraped): Home > Other

# PLA code
#
# Posted 2015-11-07 17:54 (309 views)

# Perceptron (PLA: Perceptron Learning Algorithm)
#
# Based on the "Machine Learning Foundations" course.

# /usr/bin/env python2.7
# encoding=utf-8
import numpy as np
import random,os

def verify(weight, array_x, array_y):
    '''
    Check whether `weight` classifies every sample correctly.

    A sample counts as correct only when y_i * (weight . x_i) is strictly
    positive, so a point lying exactly on the decision boundary is treated
    as a mistake. Prints how many samples are classified correctly.

    :param weight: current weight vector, one entry per feature
    :param array_x: samples, one row per sample (array or nested list)
    :param array_y: labels, one per row of array_x (presumably +1/-1)
    :return: True if every sample is classified correctly, else False
             (an empty dataset counts as fully correct)
    '''
    array_x = np.asarray(array_x)
    array_y = np.asarray(array_y)
    if len(array_y) == 0:
        # Nothing to check; np.dot would fail on an empty, 1-D array_x.
        sum_ok = 0
    else:
        # Margins y_i * (w . x_i) for all samples at once.
        margins = np.dot(array_x, weight) * array_y
        sum_ok = int(np.count_nonzero(margins > 0))
    print("%d data is classified ok!" % sum_ok)
    return sum_ok == len(array_y)

# filepath: data files are expected next to this script
ROOT_PATH = os.path.dirname(os.path.realpath(__file__))


def _load_dataset(path):
    '''
    Read a labelled dataset from `path`.

    Each non-blank line must look like "<features>\t<label>": the features
    are whitespace-separated floats and the label is an integer (presumably
    +1/-1). A progress message is printed every 100 lines.

    :param path: path of the data file
    :return: (x, y) -- a list of feature lists and a list of int labels
    '''
    features = []
    labels = []
    with open(path, 'r') as infile:
        for n, lines in enumerate(infile):
            if n % 100 == 0:
                print("%d lines readed !" % n)
            stripped = lines.strip()
            if not stripped:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # crashing with an IndexError on line[1].
                continue
            line = stripped.split('\t')
            labels.append(int(line[1]))
            # split() without an argument tolerates repeated spaces, which
            # split(' ') would turn into '' and a float() ValueError.
            features.append([float(ss) for ss in line[0].split()])
    return features, labels


## read data
path_trainfile = os.path.join(ROOT_PATH, 'traindata')
x, y = _load_dataset(path_trainfile)

# ======================================PLA=================================================
# Upper bound on passes over the training set. PLA only halts on linearly
# separable data; the cap stops the script from looping forever otherwise.
MAX_PASSES = 1000

# init
array_x = np.array(x, dtype=float)
array_y = np.array(y)
num_datasets = len(array_y)
# Size the weight vector from the data instead of hard-coding 4 features.
# NOTE(review): no constant bias feature (x0 = 1) is prepended, so the learned
# boundary must pass through the origin -- confirm the data already has one.
num_features = array_x.shape[1] if array_x.ndim == 2 else 0
weight = np.zeros(num_features)
update = 0
iteration = 0  # number of samples visited, reported at the end
converged = num_datasets == 0

# random
## random sample: a random permutation, so every sample is visited once per
## pass (sampling with replacement via randint could skip samples entirely)
num_random = list(range(num_datasets))
random.shuffle(num_random)
## ordered sample
#num_random = list(range(num_datasets))

# train
print ("we has %d datasets!" % num_datasets)
passes = 0
while not converged and passes < MAX_PASSES:
    passes += 1
    for random_iter in num_random:
        iteration += 1
        # Mistake when sign(w . x) != y. A score of exactly 0 counts as a
        # mistake, matching verify()'s strict "> 0" test; this also lets the
        # all-zero starting weight trigger the first update naturally.
        if np.dot(weight, array_x[random_iter]) * array_y[random_iter] <= 0:
            # w = w + 0.5 * y * x
            weight = weight + 0.5 * array_y[random_iter] * array_x[random_iter]
            update += 1
            # Only an update can change the verdict, so re-check just here.
            if verify(weight, array_x, array_y):
                converged = True
                break

if not converged:
    print("PLA did not converge within %d passes" % MAX_PASSES)
print ("iter : %d " % iteration)
print ("update : %d " % (update))

# test
path_testfile = os.path.join(ROOT_PATH, 'testdata')
x, y = _load_dataset(path_testfile)
array_x = np.array(x, dtype=float)
array_y = np.array(y)
num_datasets = len(array_y)

if num_datasets == 0:
    # Avoid a ZeroDivisionError when the test file has no samples.
    print("No test data found in %s" % path_testfile)
else:
    # Correctly classified test samples: y * (w . x) > 0.
    count = int(np.count_nonzero(np.dot(array_x, weight) * array_y > 0))
    rate_error = (num_datasets - count) / float(num_datasets)
    print("The error rate is %f" % rate_error)
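# (Scraped page footer) Content comes from user submissions and web collection;
# its accuracy is not guaranteed. To report infringing content, contact the site
# administrator. [link: "Click here to send me a message"]
# Tags: (none)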