小白学Tensorflow之Logistic回归
2016-07-25 16:33
453 查看
利用Tensorflow实现Logistic回归
第一,我们先来设计两个函数,使得在后续的程序中不用重复编写相同的代码。
第二,我们带入mnist的数据集,具体方法可以参考官网。
第三,构建损失函数,我们采用softmax和交叉熵来训练模型
完整代码如下:
简书同步更新:http://www.jianshu.com/p/f51f0ca4278c
84cb
第一,我们先来设计两个函数,使得在后续的程序中不用重复编写相同的代码。
def init_weights(shape): return tf.Variable(tf.random_normal(shape, stddev = 0.01)) def model(X, w): return tf.matmul(X, w)
第二,我们带入mnist的数据集,具体方法可以参考官网。
# 导入数据 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
第三,构建损失函数,我们采用softmax和交叉熵来训练模型
# 构建损失函数,我们采用softmax和交叉熵来训练模型 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y)) learning_rate = 0.01 train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
完整代码如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import input_data
def init_weights(shape): return tf.Variable(tf.random_normal(shape, stddev = 0.01)) def model(X, w): return tf.matmul(X, w)
# 导入数据 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# 设置占位符
X = tf.placeholder("float", [None, 784])
Y = tf.placeholder("float", [None, 10])
# 初始化权重
w = init_weights([784, 10])
# 构建模型
py_x = model(X, w)
# 构建损失函数,我们采用softmax和交叉熵来训练模型 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y)) learning_rate = 0.01 train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
predict_op = tf.argmax(py_x, 1)
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
for i in xrange(100):
for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
sess.run(train_op, feed_dict = {X: trX[start:end], Y: trY[start:end]})
print i, np.mean(np.argmax(teY, axis = 1) == sess.run(predict_op, feed_dict = {X: teX, Y: teY}))
简书同步更新:http://www.jianshu.com/p/f51f0ca4278c
84cb
相关文章推荐
- 【HTML5】 Audio/Video全解(集合贴)
- Power Strings --KMP
- spring笔记——ref属性的设定
- 关于linux网络基础记录
- popuWindow和软键盘共存
- 《精通javascript》几个简单的函数
- C++ QQ游戏 连连看外挂 内存挂入门
- 字母金字塔
- maven中Rhino classes (js.jar) not found - Javascript disabled的处理
- 移动WEBAPP开发常规CSS样式总结
- 动画讲解Eclipse常用快捷键
- html的一点动态效果
- 使用cocoaPods导入三方时导入头文件报错
- HDU 5742 It's All In The Mind
- 并发编程之Operation Queue和GCD
- hihoCoder #1077-> RMQ问题再临-线段树
- 安卓下拉刷新开源库对比
- 多态
- Linux定时执行任务
- Linux -trap