TensorFlow mnist 数据集练习
2018-02-03 16:45
776 查看
# -*- coding: UTF-8 -*- import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('data/',one_hot=True) trainimg = mnist.train.images trainlabel = mnist.train.labels testimg = mnist.test.images testimglabel = mnist.test.labels # print(testimg.shape) # print(trainlabel.shape) batch_size = 2**8 #分批次处理 # batch_x,batch_y = mnist.train.next_batch(batch_size) # print(batch_x.shape) # print(batch_y.shape) #None 表示无穷 placehoder 只占位 不占空间 x = tf.placeholder('float', [None, 784]) y = tf.placeholder('float', [None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) #softmax 回归 分配概率 #http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html #tf.matmul -> https://www.jianshu.com/p/19ea2d15eb14 actv = tf.nn.softmax(tf.matmul(x,W)+b) cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1)) learn_rate = 0.01 #梯度下降优化器 optm = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost) #tf.argmax -> http://blog.csdn.net/qq575379110/article/details/70538051 # 1 : 行 0: 列 pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1)) #http://blog.csdn.net/luoganttcc/article/details/70315538 数据转换 accr = tf.reduce_mean(tf.case(pred, "float")) init = tf.global_variables_initializer() training_ecpchs = 50 batch_size = 100 display = 5 with tf.Session() as sess: sess.run(init) for epoch in range(training_ecpchs): avg_cost = 0 num_batch = int(mnist.train.num_examples/batch_size) for i in range(num_batch): batch_xs,batch_ys = mnist.train.next_batch(batch_size) sess.run(optm,feed_dict={x:batch_xs,y:batch_ys}) feeds ={x:batch_xs,y:batch_ys} avg_cost +=sess.run(cost,feed_dict=feeds)/num_batch if epoch % display == 0 : feeds_train ={x:batch_xs,y:batch_ys} feeds_test ={x:mnist.test.images,y:mnist.test.labels} train_acc = sess.run(accr,feed_dict=feeds_train) test_acc = sess.run(accr,feed_dict=feeds_test) print(epoch,training_ecpchs,avg_cost,train_acc,test_acc)
相关文章推荐
- Tensorflow MNIST 数据集測试代码入门
- 学习笔记TF056:TensorFlow MNIST,数据集、分类、可视化
- tensorflow MNIST数据集的训练(线性模型)及tensorboard计算结果可视化
- tensorflow mnist实战笔记(二)制作和读取自己的数据集
- Tensorflow学习系列(三): tensorflow mnist数据集如何跑出99+的准确率
- tensorflow mnist数据集手写字识别
- tensorflow tutorials(八):手写数字数据集MNIST介绍
- Tensorflow MNIST 数据集测试代码入门
- TensorFlow mnist数据集路径 MNIST_data 数据下载问题
- Tensorflow MNIST 数据集测试代码入门
- tensorflow MNIST数据集
- tensorflow mnist数据集 cnn demo
- tensorflow MNIST数据集上简单的MLP网络
- Tensorflow mnist basic
- TensorFlow Object Detection API教程——制作自己的数据集
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
- tensorflow imagenet数据集转化
- tensorflow CNN for mnist
- tensorflowxun训练自己的数据集之从tfrecords读取数据
- TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架