您的位置:首页 > 其它

TensorFlow试用

2015-11-10 16:10 609 查看
Google发布了开源深度学习工具TensorFlow。

根据官方教程  http://tensorflow.org/tutorials/mnist/beginners/index.md  试用。

操作系统是ubuntu 14.04,64位,python 2.7,已经安装足够的python包。

1. 安装

    1.1 参考文档 http://tensorflow.org/get_started/os_setup.md#binary_installation
    

    1.2 用pip安装,需要用代理,否则连不上,这个是本地ssh到vps出去的。

    sudo pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl --proxy http://127.0.0.1:3128
    1.3 注意,我的py2.7已经安装了足够的包,如python-dev,numpy,swig等等。如果遇到缺少相应包的问题,先安装必须的包。

2. 第一个demo,test.py

------------------------------

import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')

sess = tf.Session()

print sess.run(hello)

a = tf.constant(10)

b = tf.constant(32)

print sess.run(a+b)

------------------------------

3. mnist手写识别

    3.1 下载数据库 

    在http://yann.lecun.com/exdb/mnist/下载上面提到的4个gz文件,放到本地目录如 /tmp/mnist

    3.2 下载input_data.py,放在/home/tim/test目录下

    https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py
    3.3 在/home/tim/test目录下创建文件test_tensor_flow_mnist.py,内容如下

-----------------------

#!/usr/bin/env python 

import input_data

import tensorflow as tf

mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

x = tf.placeholder("float", [None, 784])

W = tf.Variable(tf.zeros([784,10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

y_ = tf.placeholder("float", [None,10])

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()

sess = tf.Session()

sess.run(init)

for i in range(1000):

    batch_xs, batch_ys = mnist.train.next_batch(100)

    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

-----------------------

3.4 运行。大概之需要几秒钟时间,输出结果是91%左右。

4. 关于版本

4.1  pip version

pip 1.5.4 from /usr/lib/python2.7/dist-packages (python 2.7)

4.2 已经安装的python包

    有一些是用easy_install安装的,大部分是pip安装的。

pip freeze

Jinja2==2.7.2

MarkupSafe==0.18

MySQL-python==1.2.3

PAM==0.4.2

Pillow==2.3.0

Twisted-Core==13.2.0

Twisted-Web==13.2.0

adium-theme-ubuntu==0.3.4

apt-xapian-index==0.45

argparse==1.2.1

beautifulsoup4==4.2.1

chardet==2.0.1

colorama==0.2.5

command-not-found==0.3

cvxopt==1.1.4

debtagshw==0.1

decorator==3.4.0

defer==1.0.6

dirspec==13.10

duplicity==0.6.23

fp-growth==0.1.2

html5lib==0.999

httplib2==0.8

ipython==1.2.1

joblib==0.7.1

lockfile==0.8

lxml==3.3.3

matplotlib==1.4.3

nose==1.3.1

numexpr==2.2.2

numpy==1.9.2

oauthlib==0.6.1

oneconf==0.3.7

openpyxl==1.7.0

pandas==0.13.1

patsy==0.2.1

pexpect==3.1

piston-mini-client==0.7.5

pyOpenSSL==0.13

pycrypto==2.6.1

pycups==1.9.66

pycurl==7.19.3

pygobject==3.12.0

pygraphviz==1.2

pyparsing==2.0.3

pyserial==2.6

pysmbc==1.0.14.1

python-apt==0.9.3.5

python-dateutil==2.4.2

python-debian==0.1.21-nmu2ubuntu2

pytz==2012c

pyxdg==0.25

pyzmq==14.0.1

reportlab==3.0

requests==2.2.1

scipy==0.13.3

sessioninstaller==0.0.0

simplegeneric==0.8.1

simplejson==3.3.1

six==1.10.0

software-center-aptd-plugins==0.0.0

ssh-import-id==3.21

statsmodels==0.5.0

sympy==0.7.4.1

system-service==0.1.6

tables==3.1.1

tensorflow==0.5.0

tornado==3.1.1

unity-lens-photos==1.0

urllib3==1.7.1

vboxapi==1.0

wheel==0.24.0

wsgiref==0.1.2

xdiagnose==3.6.3build2

xlrd==0.9.2

xlwt==0.7.5

zope.interface==4.0.5

 
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息