TensorFlow Official Tutorial Study Notes, Part 2: The MNIST Dataset for Machine Learning Beginners (MNIST For ML Beginners)

Source: Internet  Editor: 程序博客网  Time: 2024/05/17 01:22

1. Dataset

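MNIST is an entry-level computer vision dataset of handwritten digits (0 to 9). Each image is 28×28 grayscale pixels, flattened into a 784-dimensional vector, and each label is loaded as a one-hot vector of length 10 (one_hot=True in the code). The tutorial's loader splits it into 55,000 training, 5,000 validation and 10,000 test images.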
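A minimal pure-Python sketch of this data layout, assuming a 28×28 image stored as a list of rows. The helpers flatten and one_hot are hypothetical names for illustration, not TensorFlow functions; the tutorial's input_data loader does the equivalent work for you.

```python
def flatten(image):
    """Flatten a 28x28 image (list of rows) into a 784-element list, row by row."""
    return [pixel for row in image for pixel in row]


def one_hot(label, num_classes=10):
    """Encode an integer class label as a one-hot list of length num_classes."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


# A dummy all-zero 28x28 "image" with one bright pixel at row 1, column 2.
image = [[0.0] * 28 for _ in range(28)]
image[1][2] = 1.0

x = flatten(image)
print(len(x))          # 784
print(x.index(1.0))    # 30, i.e. row 1 * 28 + column 2
print(one_hot(3))      # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Row-major flattening keeps pixel (r, c) at index r * 28 + c, which is why the model's input placeholder has shape [None, 784].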

2. Algorithm

1) Core model
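Softmax regression: a linear model whose scores (logits) are turned into class probabilities by softmax:
$$y = \mathrm{softmax}(xW + b), \qquad \mathrm{softmax}(z)_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$
where x is a 784-dimensional input row, W is a 784×10 weight matrix and b is a 10-dimensional bias. In the code, the variable y holds only the logits xW + b; softmax is applied inside the loss function.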
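As a sketch of the forward pass only, here is softmax(xW + b) in pure Python so it runs without TensorFlow. The toy sizes (3 inputs, 2 classes) and the helper names softmax and predict are illustrative assumptions; the tutorial itself uses 784 inputs and 10 classes.

```python
import math


def softmax(z):
    """Numerically stable softmax: subtracting max(z) does not change the result."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def predict(x, W, b):
    """Return softmax(xW + b) for a single input row x.

    x: list of n_in floats; W: n_in x n_out nested list; b: list of n_out floats.
    """
    logits = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
              for j in range(len(b))]
    return softmax(logits)


# Toy sizes (3 inputs, 2 classes) instead of MNIST's 784 inputs and 10 classes.
x = [1.0, 0.0, 2.0]
W = [[0.5, -0.5],
     [1.0, 1.0],
     [0.0, 0.25]]
b = [0.0, 0.0]
probs = predict(x, W, b)
print(probs)  # logits are [0.5, 0.0], so class 0 gets the larger probability

# With zero-initialised parameters (tf.zeros in the tutorial code), every
# one of the 10 classes gets probability 1/10, whatever the input.
x784 = [0.5] * 784
W0 = [[0.0] * 10 for _ in range(784)]
b0 = [0.0] * 10
print(predict(x784, W0, b0))
```

The zero-initialised case shows the model's starting point: before any training it predicts a uniform distribution over the ten digits.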

2) Cost function
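Cross-entropy between the one-hot true distribution y' and the predicted distribution y, averaged over the examples in a batch:
$$H_{y'}(y) = -\sum_i y'_i \log(y_i)$$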
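The code comment explains that computing log(softmax(y)) directly can be numerically unstable, which is why the tutorial uses tf.nn.softmax_cross_entropy_with_logits on the raw logits. The pure-Python sketch below illustrates that idea with the log-sum-exp trick; it is an assumption-labelled illustration of the principle, not TensorFlow's actual implementation, and the helper names are hypothetical.

```python
import math


def softmax(z):
    """Numerically stable softmax (shift by max(z) before exponentiating)."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(y_true, y_pred):
    """H_{y'}(y) = -sum_i y'_i * log(y_i) for one example.

    Terms with y'_i == 0 are skipped (they contribute 0 by convention).
    """
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)


def cross_entropy_from_logits(y_true, logits):
    """Same loss computed directly from logits with the log-sum-exp trick:

    log softmax(z)_i = z_i - (m + log(sum_j exp(z_j - m))),  where m = max(z).
    The argument of log() is always >= 1 here, so it cannot hit log(0).
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum_exp) for t, z in zip(y_true, logits))


y_true = [0.0, 1.0, 0.0]
logits = [1.0, 2.0, 0.5]
print(cross_entropy(y_true, softmax(logits)))
print(cross_entropy_from_logits(y_true, logits))  # same value up to rounding

# Extreme logits: softmax of the true class underflows to exactly 0.0.
y_true2 = [0.0, 1.0]
logits2 = [1000.0, 0.0]
naive_failed = False
try:
    cross_entropy(y_true2, softmax(logits2))
except ValueError:
    naive_failed = True  # math.log(0.0) raises ValueError in pure Python
print(naive_failed)                                 # True
print(cross_entropy_from_logits(y_true2, logits2))  # 1000.0
```

In pure Python the naive version raises an exception; in a TensorFlow graph, log(0) would instead typically propagate as inf or NaN into the loss.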

3) Optimization
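Gradient descent: repeatedly move W and b a small step against the gradient of the cross-entropy. The code uses a learning rate of 0.5 and runs 1000 steps, each on a mini-batch of 100 training examples (stochastic gradient descent).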
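In the tutorial, GradientDescentOptimizer(0.5).minimize(cross_entropy) derives the gradients automatically. As a hedged illustration of what one step does for this particular model, here is a pure-Python sketch using the standard result that, for softmax followed by cross-entropy, the gradient with respect to each logit is p_j - y'_j, averaged over the batch. The 2-pixel toy data and the helper names (logits, mean_loss, gd_step) are hypothetical; only the learning rate 0.5 comes from the tutorial code.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def logits(x, W, b):
    """xW + b for one input row x."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def mean_loss(X, Y, W, b):
    """Mean softmax cross-entropy over a batch (labels are one-hot)."""
    total = 0.0
    for x, y in zip(X, Y):
        p = softmax(logits(x, W, b))
        total -= sum(t * math.log(q) for t, q in zip(y, p) if t > 0)
    return total / len(X)


def gd_step(X, Y, W, b, lr=0.5):
    """One gradient-descent step on W and b (updated in place).

    Per example, d(loss)/d(logit_j) = p_j - y_j; dividing by the batch size
    gives the gradient of the mean loss. All gradients are computed with the
    old parameters before any update is applied.
    """
    n = len(X)
    n_in, n_out = len(W), len(b)
    grad_W = [[0.0] * n_out for _ in range(n_in)]
    grad_b = [0.0] * n_out
    for x, y in zip(X, Y):
        p = softmax(logits(x, W, b))
        for j in range(n_out):
            delta = (p[j] - y[j]) / n
            grad_b[j] += delta
            for i in range(n_in):
                grad_W[i][j] += x[i] * delta
    for i in range(n_in):
        for j in range(n_out):
            W[i][j] -= lr * grad_W[i][j]
    for j in range(n_out):
        b[j] -= lr * grad_b[j]


# Toy batch: 2-pixel "images", 2 classes; zero-initialised like tf.zeros.
X = [[1.0, 0.0], [0.0, 1.0]]
Y = [[1.0, 0.0], [0.0, 1.0]]
W = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]

before = mean_loss(X, Y, W, b)
for _ in range(100):
    gd_step(X, Y, W, b, lr=0.5)
after = mean_loss(X, Y, W, b)
print(before, after)  # before is log(2); after is smaller
```

Because this loss is convex in W and b, a small enough learning rate makes it decrease steadily; the TensorFlow version applies the same kind of update to 784×10 weights using batches of 100.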

3. Code

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A very simple MNIST classifier.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/beginners
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b

  # Define loss and optimizer
  y_ = tf.placeholder(tf.float32, [None, 10])

  # The raw formulation of cross-entropy,
  #
  #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
  #                                 reduction_indices=[1]))
  #
  # can be numerically unstable.
  #
  # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
  # outputs of 'y', and then average across the batch.
  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  sess = tf.InteractiveSession()
  tf.global_variables_initializer().run()
  # Train
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels}))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
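The last lines of main compute accuracy by comparing tf.argmax of the logits with tf.argmax of the one-hot labels, then averaging the matches. Because softmax is monotonic, the argmax of the logits equals the argmax of the probabilities, so the code can skip softmax at this point. A pure-Python equivalent for illustration, with hypothetical helpers argmax and accuracy and made-up toy data:

```python
def argmax(values):
    """Index of the largest element; the first such index on ties."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits_batch, labels_batch):
    """Fraction of rows where argmax(logits) == argmax(one-hot label)."""
    correct = [argmax(z) == argmax(y) for z, y in zip(logits_batch, labels_batch)]
    return sum(correct) / len(correct)


logits_batch = [[2.0, 0.1, -1.0],
                [0.3, 0.2, 0.9],
                [0.0, 1.5, 0.2],
                [1.0, 0.0, 3.0]]
labels_batch = [[1, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
                [0, 0, 1]]
print(accuracy(logits_batch, labels_batch))  # 0.75 (row 2 is misclassified)
```

Summing the boolean matches and dividing by the batch size mirrors tf.reduce_mean(tf.cast(correct_prediction, tf.float32)).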

4. References

[1] TensorFlow tutorial "MNIST For ML Beginners", https://www.tensorflow.org/get_started/mnist/beginners
