您的位置:首页 > 其它

TensorFlow MNIST 数据集练习

2018-02-03 16:45 776 查看
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from the 'data/' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testimglabel = mnist.test.labels

# Placeholders only reserve an input slot in the graph and hold no data.
# `None` leaves the first (batch) dimension unspecified so that a batch of
# any size can be fed at run time.
x = tf.placeholder('float', [None, 784])
y = tf.placeholder('float', [None, 10])

# Model parameters, both initialised to zeros.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Softmax regression: turns the linear scores x*W + b into class probabilities.
# http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
# tf.matmul -> https://www.jianshu.com/p/19ea2d15eb14
# NOTE: this statement used to be fused onto the end of the comment line
# above, which left `actv` undefined for everything that follows.
actv = tf.nn.softmax(tf.matmul(x, W) + b)

# Cross-entropy: summed over the class axis (axis 1), averaged over the batch.
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(actv), reduction_indices=1))

# Plain gradient-descent optimizer minimising the cross-entropy.
learn_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost)

# tf.argmax -> http://blog.csdn.net/qq575379110/article/details/70538051
# Axis 1 selects the arg-max within each row, i.e. per sample.
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))

# Type conversion: http://blog.csdn.net/luoganttcc/article/details/70315538
# Cast the boolean hit/miss vector to float so that its mean is the accuracy.
# (Was `tf.case`, a typo for `tf.cast`.)
accr = tf.reduce_mean(tf.cast(pred, "float"))

init = tf.global_variables_initializer()

# Training hyper-parameters.
training_ecpchs = 50  # number of epochs (identifier spelling kept as-is)
batch_size = 100      # samples per mini-batch
display = 5           # report metrics every `display` epochs

# Train the model with mini-batch gradient descent and periodically report
# the average epoch cost together with the training and test accuracy.
# The number of mini-batches per epoch does not change between epochs, so it
# is computed once up front.
num_batch = mnist.train.num_examples // batch_size

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_ecpchs):
        avg_cost = 0
        for _ in range(num_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys}
            # One optimisation step on this mini-batch ...
            sess.run(optm, feed_dict=feeds)
            # ... then accumulate the post-update cost into the epoch average.
            avg_cost += sess.run(cost, feed_dict=feeds) / num_batch
        if epoch % display == 0:
            # Training accuracy is measured on the last mini-batch of the
            # epoch only; test accuracy uses the complete test set.
            feeds_train = {x: batch_xs, y: batch_ys}
            feeds_test = {x: mnist.test.images, y: mnist.test.labels}
            train_acc = sess.run(accr, feed_dict=feeds_train)
            test_acc = sess.run(accr, feed_dict=feeds_test)
            print(epoch, training_ecpchs, avg_cost, train_acc, test_acc)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: