TensorFlow试用

来源:互联网 发布:网络调教男奴手段 编辑:程序博客网 时间:2024/05/16 08:29

Google发布了开源深度学习工具TensorFlow。


根据官方教程  http://tensorflow.org/tutorials/mnist/beginners/index.md  试用。


操作系统是ubuntu 14.04,64位,python 2.7,已经安装足够的python包。



1. 安装

    1.1 参考文档 http://tensorflow.org/get_started/os_setup.md#binary_installation
    
    1.2 用pip安装,需要用代理,否则连不上,这个是本地ssh到vps出去的。

    sudo pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl --proxy http://127.0.0.1:3128

    1.3 注意,我的py2.7已经安装了足够的包,如python-dev,numpy,swig等等。如果遇到缺少相应包的问题,先安装必须的包。

2. 第一个demo,test.py
------------------------------
import tensorflow as tf

hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print sess.run(hello)

a = tf.constant(10)
b = tf.constant(32)
print sess.run(a+b)

------------------------------


3. mnist手写识别
    3.1 下载数据库 
    在http://yann.lecun.com/exdb/mnist/下载上面提到的4个gz文件,放到本地目录如 /tmp/mnist

    3.2 下载input_data.py,放在/home/tim/test目录下
    https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py

    3.3 在/home/tim/test目录下创建文件test_tensor_flow_mnist.py,内容如下
-----------------------
#!/usr/bin/env python 

import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("/tmp/mnist", one_hot=True)

x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
-----------------------

3.4 运行。大概之需要几秒钟时间,输出结果是91%左右。



4. 关于版本

4.1  pip version


pip 1.5.4 from /usr/lib/python2.7/dist-packages (python 2.7)


4.2 已经安装的python包

    有一些是用easy_install安装的,大部分是pip安装的。

pip freeze


Jinja2==2.7.2
MarkupSafe==0.18
MySQL-python==1.2.3
PAM==0.4.2
Pillow==2.3.0
Twisted-Core==13.2.0
Twisted-Web==13.2.0
adium-theme-ubuntu==0.3.4
apt-xapian-index==0.45
argparse==1.2.1
beautifulsoup4==4.2.1
chardet==2.0.1
colorama==0.2.5
command-not-found==0.3
cvxopt==1.1.4
debtagshw==0.1
decorator==3.4.0
defer==1.0.6
dirspec==13.10
duplicity==0.6.23
fp-growth==0.1.2
html5lib==0.999
httplib2==0.8
ipython==1.2.1
joblib==0.7.1
lockfile==0.8
lxml==3.3.3
matplotlib==1.4.3
nose==1.3.1
numexpr==2.2.2
numpy==1.9.2
oauthlib==0.6.1
oneconf==0.3.7
openpyxl==1.7.0
pandas==0.13.1
patsy==0.2.1
pexpect==3.1
piston-mini-client==0.7.5
pyOpenSSL==0.13
pycrypto==2.6.1
pycups==1.9.66
pycurl==7.19.3
pygobject==3.12.0
pygraphviz==1.2
pyparsing==2.0.3
pyserial==2.6
pysmbc==1.0.14.1
python-apt==0.9.3.5
python-dateutil==2.4.2
python-debian==0.1.21-nmu2ubuntu2
pytz==2012c
pyxdg==0.25
pyzmq==14.0.1
reportlab==3.0
requests==2.2.1
scipy==0.13.3
sessioninstaller==0.0.0
simplegeneric==0.8.1
simplejson==3.3.1
six==1.10.0
software-center-aptd-plugins==0.0.0
ssh-import-id==3.21
statsmodels==0.5.0
sympy==0.7.4.1
system-service==0.1.6
tables==3.1.1
tensorflow==0.5.0
tornado==3.1.1
unity-lens-photos==1.0
urllib3==1.7.1
vboxapi==1.0
wheel==0.24.0
wsgiref==0.1.2
xdiagnose==3.6.3build2
xlrd==0.9.2
xlwt==0.7.5
zope.interface==4.0.5

 

1 1