读MNIST源码(二):tensorflow基础
来源:互联网 发布:数控冲床编程怎么学 编辑:程序博客网 时间:2024/05/29 10:03
读MNIST源码(二):tensorflow基础
- 读MNIST源码二tensorflow基础
- tensorflow运行机理
- 图
- OP
- tensor
- session
- run
- 一个简单的softmax分类器
- 变量和占位符
- reduce_mean
- compute_gradients
- tensorflow运行机理
在本节中将继续以 mnist_softmax.py这个文件为依托,记录一些tensorflow的基础知识
mnist_softmax.py 这个例程实现了一个简单的softmax分类器。在对tensorflow有一定了解的基础上,本文结合官网文档,对分类器实现中的一些要点进行解释
1.tensorflow运行机理
Tensorflow是一个机器学习编程平台(简称TF),以c接口为分界线可以分为前端和后台,前端使用图来描述整个机器学习计算的流程,后台则由高效的c语言负责具体的科学计算。
TF程序执行的过程就是构建图和计算图的过程。所以理解tensorflow程序的关键在于理解图。下面结合tensorflow的官方文档来看图构建和计算的基本介绍
图
图由运行结点和数据结点组成,一个tensorflowf程序默认会构建一个图,这在大多数机器学习过程里是够用的:
A
Graph
contains a set oftf.Operation
objects, which represent units of computation; andtf.Tensor
objects, which represent the units of data that flow between operations.A default
Graph
is always registered, and accessible by callingtf.get_default_graph
. To add an operation to the default graph, simply call one of the functions that defines a newOperation
:
OP
运行结点是operation类的对象(简称op),数据结点是tensor类的对象(以后称tensor)。一个op往往会以0到多个tensor为输入,并以0到多个tensor为输出,op可以由用户自定义构建:
An
Operation
is a node in a TensorFlowGraph
that takes zero or moreTensor
objects as input, and produces zero or moreTensor
objects as output. Objects of typeOperation
are created by calling a Python op constructor (such astf.matmul
) ortf.Graph.create_op
.
tensor
tensor**是op输出数据的一个句柄,它并不实际存储数据,只提供处理这些数据的方法
A
Tensor
is a symbolic handle to one of the outputs of anOperation
. It does not hold the values of that operation’s output, but instead provides a means of computing those values in a TensorFlowtf.Session
.This class has two primary purposes:
- A
Tensor
can be passed as an input to anotherOperation
. This builds a dataflow connection between operations, which enables TensorFlow to execute an entireGraph
that represents a large, multi-step computation.- After the graph has been launched in a session, the value of the
Tensor
can be computed by passing it totf.Session.run
.t.eval()
is a shortcut for callingtf.get_default_session().run(t)
.
张量(tensor)是tensorflow中最为核心和基本的数据结构,张量十分类似于一个多维向量。每一个tensor都有rank属性、shape属性、dtype属性,分别代表着向量的维数,形状(看下面例子),数据类型(浮点、整形等)
3 # a rank 0 tensor; a scalar with shape [][1., 2., 3.] # a rank 1 tensor; a vector with shape [3][[1., 2., 3.], [4., 5., 6.]] # a rank 2 tensor; a matrix with shape [2, 3][[[1., 2., 3.]], [[7., 8., 9.]]] # a rank 3 tensor with shape [2, 1, 3]
session
构建图完成后,就需要通过会话(session)来执行图。session是前端和后台的接口,他封装了op的执行环境和tensor的运算结果。
在初学阶段可以认为session和InteractiveSession是基本一致的。
A
Session
object encapsulates the environment in whichOperation
objects are executed, andTensor
objects are evaluated. For example:
run()
run( fetches, feed_dict=None, options=None, run_metadata=None)
Runs operations and evaluates tensors in
fetches
.This method runs one “step” of TensorFlow computation, by running the necessary graph fragment to execute every
Operation
and evaluate everyTensor
infetches
, substituting the values infeed_dict
for the corresponding input values.
2.一个简单的softmax分类器
这里实现了一个简单的softmax分类器
原理略去不谈,只谈在代码实现中需要注意的几个点,详细过程可参考
变量和占位符
tf.Variable
是变量,一般用来存储参数,变量属于op,
tf.placeholder
是占位符,一般用来存储训练数据,能根据需要更改数据的形状
# Create the model x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.matmul(x, W) + b #这里用来存放标签数据 y_ = tf.placeholder(tf.float32, [None, 10])
变量在使用前都需要进行初始化
tf.global_variables_initializer().run()
reduce_mean()
代码中loss采用了交叉熵,在计算交叉熵的过程中中用了tf.reduce_mean
这个函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ *tf.log(tf.nn.softmax(y)), reduction_indices=[1]))
reduce_mean( input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None)
这个函数是求平均值的,如果不给第二个参数的话,这个函数会返回tensor中所有元素的平均值,第二个参数指定了从第几个维度上去求平均值,这个维数
举例说明:
# 'x' is [[1., 2.]# [3., 4.]]
x是一个2维数组,分别调用reduce_mean
函数如下:
首先求平均值:
tf.reduce_mean(x) ==> 2.5 #如果不指定第二个参数,那么就在所有的元素中取平均值tf.reduce_mean(x, 0) ==> [2., 3.] #指定第二个参数为0,则第一维的元素取平均值,即每一列求平均值tf.reduce_mean(x, 1) ==> [1.5, 3.5] #指定第二个参数为1,则第二维的元素取平均值,即每一行求平均值
compute_gradients()
这里通过这个函数来看以下训练过程
#调用内部op,实现交叉熵的优化train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
查看这个函数
minimize( loss, global_step=None, var_list=None, gate_gradients=GATE_OP, aggregation_method=None, colocate_gradients_with_ops=False, name=None, grad_loss=None)
注意loss和var_list这两个参数,其中loss是一个包含参数的tensor,而var_list指定了需要优化的参数,如果不作制定,则默认为计算图中收集到的默认参数,即与这个op相连的所有输入参数。
loss: A
Tensor
containing the value to minimize.var_list: Optional list or tuple of
Variable
objects to update to minimizeloss
. Defaults to the list of variables collected in the graph under the keyGraphKeys.TRAINABLE_VARIABLES
.
对于训练过程:
#重复1000次for _ in range(1000): #这里是自定义的函数,随机从数据库取100组数据,详细可参考mnist.py文件 batch_xs, batch_ys = mnist.train.next_batch(100) #run第一个参数是一个op,后边可以给op“喂”参数 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
下面附上完整的函数代码
def main(_): # Import data mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) # Create the model x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.matmul(x, W) + b # Define loss and optimizer y_ = tf.placeholder(tf.float32, [None, 10]) # The raw formulation of cross-entropy, # # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), # reduction_indices=[1])) # # can be numerically unstable. # # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw # outputs of 'y', and then average across the batch. cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession() tf.global_variables_initializer().run() # Train for _ in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # Test trained model correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
- 读MNIST源码(二):tensorflow基础
- TensorFlow 训练 MNIST 数据(二)
- TensorFlow学习(二),深入MNIST
- TensorFlow学习笔记(二)MNIST入门
- Tensorflow入门二 mnist识别(一)
- Tensorflow入门三 mnist识别(二)
- TensorFlow 从入门到精通(二):MNIST 例程源码分析
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
- tensorflow教程学习二MNIST
- TensorFlow基础(二)
- Tensorflow 基础(二)
- 深度学习框架TensorFlow学习(二)----简单实现Mnist
- tensorflow在mnist集上的使用示例(二)
- TensorFlow学习笔记(二)---MNIST代码分析
- TensorFlow学习笔记(二)MNIST手写数字识别
- <二>、TensorFlow之MNIST机器学习入门(1)
- OLTP && OLAP(DSS)的区别
- Node.js中的常用工具类util
- SQL Server安全管理
- 【理解】汉诺塔问题,新手看这里
- zk的jar包冲突:java.lang.NoSuchMethodError: org.apache.zookeeper.ZooKeeper.getChildren(Ljava/lang/String;
- 读MNIST源码(二):tensorflow基础
- oracle ebs
- 猫狗大战遇到问题
- Linux内核编程 -- 从HelloWord到基于NetFilter的Linux驱动Demo
- 如何关闭Golang中的HTTP连接 How to Close Golang's HTTP connection
- ORA-39365 Error Reported by DataPump Import (IMPDP) When SUPPLEMENTAL_LOG_DATA_MIN Is Set To YES (文档
- WLAN连不上网和以太网连不上网的解决办法
- 【MIUI8_7.6.10】红米NOTE3 全网通 KENZO 高通骁龙650 基于安卓M(Android 6.0)修改精简优化版本
- ubuntu下实用技巧