您的位置:首页 > 编程语言 > Python开发

Python Multinomial Logistics 实现MNIST分类

2016-03-12 13:36 561 查看
推导过程有空再整理下来

结果不是特别好,可能训练次数要增加,或者还有其他什么原因

# -*- coding: utf-8 -*-
"""
Created on Fri Mar 11 10:43:54 2016

@author: Administrator
"""

import os

# Switch the working directory to the data folder so the .npy files
# referenced below resolve by bare filename.
os.getcwd()
os.chdir("D:\\Workspace")

import numpy as np
import struct
import matplotlib.pyplot as plt
from numpy import hstack

# Load the preprocessed MNIST training data saved by an earlier step.
# images: (60000, 784) pixel rows; labels: (60000, 1) digit classes.
images = np.load("Train_Images.npy")
labels = np.load("Train_Labels.npy")

# Append the label as a final column so a single row shuffle keeps
# features and labels aligned: trainset has shape (60000, 785).
# (The original built this hstack twice — once only to inspect .shape.)
trainset = np.hstack((images, labels))
del images, labels  # release the large intermediates

epsilon = 0.01  # gradient-ascent learning rate

# Randomly draw num_examples training rows. NOTE: randint samples WITH
# replacement, so duplicates are possible; use
# np.random.choice(numImages, num_examples, replace=False) for distinct rows.
numImages = trainset.shape[0]
num_examples = 2000  # sample size
tmp = trainset[np.random.randint(0, numImages, num_examples), ]

# tmp already has exactly num_examples rows, so plain slicing suffices
# (the original re-indexed with range(0, num_examples)).
X_raw = tmp[:, :-1]  # pixel features, shape (num_examples, 784)
y_raw = tmp[:, -1]   # digit labels,  shape (num_examples,)

del tmp

"""
# X 0-1 化处理
"""
X=[[0 for col in range(X_raw.shape[0])] for row in range(X_raw.shape[1])]
X=np.array(X)
X=X.reshape(X_raw.shape[0], X_raw.shape[1])
X.shape # (500, 784)

for i in range(X.shape[0]):
for j in range(X.shape[1]):
if X_raw[i,j]>0:
X[i,j]=1

# X=X_raw.copy()
X.shape

"""
# y 扩展
"""

y_label=[[0 for col in range(num_examples)] for row in range(10)]
y_label=np.array(y_label)
y_label=y_label.reshape(num_examples, 10)
y_label.shape # (500, 10)

i=0
while i<num_examples:
y_label[i, y_raw[i]]=1
i+=1

# y_label[:10,]
# y_raw[:10]

# Fixed seed so the weight initialization (and hence the run) is reproducible.
np.random.seed(0)

# One weight column per class: theta maps 784 pixels -> 10 digit scores.
theta = np.random.randn(X.shape[1], y_label.shape[1])
theta.shape  # (784, 10)

# Prediction buffer, same dtype/shape as the raw labels.
y_pred = y_raw.copy()

# Full-batch gradient ascent on the mean softmax log-likelihood.
for i in range(20000):
    # Forward pass. Subtracting the row max before exp is mathematically
    # a no-op for softmax but prevents overflow in np.exp.
    z = X.dot(theta)
    z -= z.max(axis=1, keepdims=True)
    h = np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True)

    # Gradient of the mean log-likelihood; step size epsilon.
    delta = X.T.dot(y_label - h)
    theta += epsilon * delta / num_examples

    if i % 100 == 0:
        # Diagnostics are hoisted inside the reporting branch: the original
        # recomputed a second full forward pass (for the post-update loss)
        # and a per-sample Python argmax loop on EVERY iteration, although
        # the results were only printed every 100 steps.
        z1 = X.dot(theta)
        z1 -= z1.max(axis=1, keepdims=True)
        h1 = np.exp(z1) / np.sum(np.exp(z1), axis=1, keepdims=True)
        probs = np.sum(np.log(h1) * y_label, axis=1, keepdims=True)
        loss = -np.sum(probs) / num_examples

        # Vectorized prediction and error rate (fraction misclassified).
        y_pred = np.argmax(h, axis=1)
        error = 1.0 - np.mean(y_pred == y_raw)

        # print() call form works under both Python 2 and Python 3
        # (the original used the Python-2-only print statement).
        print("Iteration %d with Loss = %f  and Error rate = %f" % (i, loss, error))

# Final class probabilities under the trained weights.
z = X.dot(theta)
z -= z.max(axis=1, keepdims=True)
h = np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True)
Iteration 17000 with Loss = 0.576084 and Error rate = 0.119500

Iteration 17100 with Loss = 0.573238 and Error rate = 0.118000

Iteration 17200 with Loss = 0.570414 and Error rate = 0.117500

Iteration 17300 with Loss = 0.567612 and Error rate = 0.116000

Iteration 17400 with Loss = 0.564831 and Error rate = 0.115000

Iteration 17500 with Loss = 0.562071 and Error rate = 0.114500

Iteration 17600 with Loss = 0.559332 and Error rate = 0.114000

Iteration 17700 with Loss = 0.556614 and Error rate = 0.113500

Iteration 17800 with Loss = 0.553917 and Error rate = 0.113500

Iteration 17900 with Loss = 0.551239 and Error rate = 0.113500

Iteration 18000 with Loss = 0.548582 and Error rate = 0.113500

Iteration 18100 with Loss = 0.545944 and Error rate = 0.113500

Iteration 18200 with Loss = 0.543326 and Error rate = 0.112500

Iteration 18300 with Loss = 0.540727 and Error rate = 0.112000

Iteration 18400 with Loss = 0.538147 and Error rate = 0.111500

Iteration 18500 with Loss = 0.535586 and Error rate = 0.111500

Iteration 18600 with Loss = 0.533043 and Error rate = 0.111500

Iteration 18700 with Loss = 0.530518 and Error rate = 0.111500

Iteration 18800 with Loss = 0.528012 and Error rate = 0.111500

Iteration 18900 with Loss = 0.525523 and Error rate = 0.111000

Iteration 19000 with Loss = 0.523052 and Error rate = 0.111000

Iteration 19100 with Loss = 0.520598 and Error rate = 0.109500

Iteration 19200 with Loss = 0.518161 and Error rate = 0.108500

Iteration 19300 with Loss = 0.515741 and Error rate = 0.108500

Iteration 19400 with Loss = 0.513338 and Error rate = 0.108000

Iteration 19500 with Loss = 0.510951 and Error rate = 0.107500

Iteration 19600 with Loss = 0.508580 and Error rate = 0.107500

Iteration 19700 with Loss = 0.506226 and Error rate = 0.107000

Iteration 19800 with Loss = 0.503887 and Error rate = 0.106500

Iteration 19900 with Loss = 0.501564 and Error rate = 0.106500
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: