Notes on TensorFlow (2): Get Started


Tensor

A tensor is a generalized matrix; put plainly, a multidimensional array. Abstractly, the computation in deep learning is a flow of tensors: an input tensor flows through the layers one by one and finally becomes the output. This is also where the name TensorFlow comes from.
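
To make this concrete, a tensor is characterized by its rank (number of dimensions) and shape. A quick sketch using the TensorFlow 1.x API (tf.constant is introduced below):

import tensorflow as tf

scalar = tf.constant(3.0)                   # rank 0, shape ()
vector = tf.constant([1.0, 2.0, 3.0])       # rank 1, shape (3,)
matrix = tf.constant([[1., 2.], [3., 4.]])  # rank 2, shape (2, 2)
print(scalar.shape, vector.shape, matrix.shape)  # () (3,) (2, 2)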

Computational Graph

A computational graph models a computation as a directed graph in which each node represents either a value (a tensor or a scalar) or an operation. Per the official documentation, programming in TensorFlow roughly breaks into two steps:
1. Build the graph
2. Run the graph
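
A minimal sketch of this split (TF 1.x API); note that merely building the graph computes nothing, so printing a node shows a Tensor object rather than its value:

import tensorflow as tf

# step 1: build the graph -- no computation happens here
node = tf.constant(3.0) + tf.constant(4.0)
print(node)  # prints a Tensor object, not 7.0

# step 2: run the graph -- values are produced only inside a session
with tf.Session() as sess:
    print(sess.run(node))  # 7.0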

Session

A session controls how a graph is run.

import tensorflow as tf

# build graph
a = tf.constant(2)
b = tf.placeholder(tf.int32)
c = a + b

# run graph
with tf.Session() as sess:
    c_val = sess.run([c], feed_dict={b: 20})
    print(c_val)
[22]

Three commonly used constant/variable nodes

  • tf.constant

A constant node; its value cannot change.

  • tf.placeholder
placeholder(
    dtype,
    shape=None,
    name=None
)

A placeholder node; its value likewise cannot change. Model inputs usually use placeholder. See https://www.tensorflow.org/api_docs/python/tf/placeholder for details.

  • tf.Variable

A variable node; its value can change. Model parameters (e.g., weights/biases) usually use Variable.

Variable(<initial-value>, name=<optional-name>)

A Variable must be given an initial value at creation; this initial value determines the Variable's shape and dtype. See https://www.tensorflow.org/api_docs/python/tf/Variable for details. A short sketch contrasting the three node types follows.
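
A minimal sketch contrasting the three node types (TF 1.x; the variable names are illustrative):

import tensorflow as tf

c = tf.constant(1.0)            # fixed when the graph is built
p = tf.placeholder(tf.float32)  # value supplied at run time via feed_dict
v = tf.Variable(2.0)            # mutable state; shape/dtype taken from the initial value

out = c + p + v
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # Variables must be initialized first
    print(sess.run(out, feed_dict={p: 3.0}))     # prints 6.0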

Linear regression as an example

An example from the official site: https://www.tensorflow.org/get_started/get_started

build graph

import tensorflow as tf

# input
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

# parameters
W = tf.Variable([.3], dtype=tf.float32)
b = tf.Variable([-.3], dtype=tf.float32)

# output
linear_model = W * x + b

# loss
squared_deltas = tf.square(linear_model - y)
loss = tf.reduce_sum(squared_deltas)

# SGD
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
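
In equation form, the graph above encodes the sum-of-squared-errors loss

loss = Σ_i (W·x_i + b − y_i)²

and each run of the train op applies one gradient-descent update, e.g. W ← W − 0.01·∂loss/∂W (and likewise for b), where 0.01 is the learning rate passed to GradientDescentOptimizer.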

run graph

# training data
x_train = [1, 2, 3, 4]
y_train = [0, -1, -2, -3]

# training loop
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)  # run parameter initialization
for i in range(1000):
    _, curr_W, curr_b, curr_loss = sess.run([train, W, b, loss], {x: x_train, y: y_train})
    if i % 200 == 0:
        print("Iteration %d, W: %s, b: %s, loss: %s" % (i, curr_W, curr_b, curr_loss))
Iteration 0, W: [ 0.30000001], b: [-0.30000001], loss: 23.66
Iteration 200, W: [-0.99566007], b: [ 0.98724014], loss: 0.000108768
Iteration 400, W: [-0.99996454], b: [ 0.99989575], loss: 7.26898e-09
Iteration 600, W: [-0.99999911], b: [ 0.99999744], loss: 4.20641e-12
Iteration 800, W: [-0.99999911], b: [ 0.99999744], loss: 4.20641e-12
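
As the log shows, training converges toward the exact solution W = -1, b = 1, at which the loss is 0. To inspect the final parameters, one could append a couple of lines after the loop (a small addition to the script above):

print(sess.run([W, b, loss], {x: x_train, y: y_train}))
sess.close()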