TensorFlow Study Notes (30): tf.gradients and tf.stop_gradient()

Source: Internet  Editor: 程序博客网  Date: 2024/05/01 20:29

tf.gradients

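TensorFlow provides tf.gradients(ys, xs) for computing gradients. Note that every x in xs must be connected to ys. For an x that ys does not depend on, tf.gradients returns None instead of a tensor, and passing that None to sess.run raises an error.
The code below defines two variables, w1 and w2, but res depends only on w1.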

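```python
# wrong
import tensorflow as tf

w1 = tf.Variable([[1, 2]])
w2 = tf.Variable([[3, 4]])
res = tf.matmul(w1, [[2], [1]])
# res does not depend on w2, so grads == [<Tensor>, None]
grads = tf.gradients(res, [w1, w2])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    re = sess.run(grads)  # fails here: None cannot be fetched
    print(re)
```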

Error message:
TypeError: Fetch argument None has invalid type

```python
# right
import tensorflow as tf

w1 = tf.Variable([[1, 2]])
w2 = tf.Variable([[3, 4]])
res = tf.matmul(w1, [[2], [1]])
# only request the gradient with respect to w1, which res depends on
grads = tf.gradients(res, [w1])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    re = sess.run(grads)
    print(re)
# [array([[2, 1]], dtype=int32)]
```
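A minimal sketch, assuming a TensorFlow install that provides `tensorflow.compat.v1` (so the TF1-style graph and `Session` API used in this note also runs on TensorFlow 2.x). It shows that `tf.gradients` itself does not raise for the unconnected `w2` but returns `None`, and that dropping the `None` entries before `sess.run` avoids the error. Float values replace the integers above (my choice, since gradients are normally taken with respect to floating-point tensors); the names `xs`, `connected` and `values` are illustrative.

```python
# Sketch: tf.gradients returns None for an unconnected variable; the error
# only appears when sess.run is asked to fetch that None.
# Uses the TF1 graph API via tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

w1 = tf.Variable([[1.0, 2.0]])
w2 = tf.Variable([[3.0, 4.0]])
res = tf.matmul(w1, [[2.0], [1.0]])

xs = [w1, w2]
grads = tf.gradients(res, xs)  # no exception here; grads[1] is None
print(grads)

# Keep only the (gradient, variable) pairs that are actually connected to res.
connected = [(g, v) for g, v in zip(grads, xs) if g is not None]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    values = sess.run([g for g, _ in connected])

for (_, v), val in zip(connected, values):
    print(v.op.name, val.tolist())  # w1's gradient: [[2.0, 1.0]]
```

The same filtering idea is what makes it safe to build summaries or update ops over a list of gradients that may contain `None`.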

tf.stop_gradient()

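tf.stop_gradient() blocks backpropagation (BP) of gradients through a node: in the forward pass it returns its input unchanged, but no gradient flows back through it.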

```python
import tensorflow as tf

w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
# b = w1 * 3.0 * w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
# Output:
# [None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
```

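As the output shows, once a node is stopped, the gradient arriving at that node can no longer propagate back past it. Since the only path from b to w1 goes through node a, the gradient returned for w1 is None.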
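To make this concrete, here is a small sketch under the same assumption (`tensorflow.compat.v1` with v2 behavior disabled) that actually evaluates the graph above: `a_stoped` has the same forward value as `a`, the gradient with respect to `w2` equals the value of `a_stoped`, and the gradient with respect to `w1` is `None`. The variable names `grad_w1`, `grad_w2`, `a_val` and so on are illustrative.

```python
# Sketch: stop_gradient is the identity in the forward pass but cuts the
# backward path. Uses the TF1 graph API via tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)  # same value as a, but no gradient flows back
b = tf.multiply(a_stoped, w2)   # b = w1 * 3.0 * w2 = 12.0
grad_w1, grad_w2 = tf.gradients(b, xs=[w1, w2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a_val, a_stoped_val, b_val, g2 = sess.run([a, a_stoped, b, grad_w2])

print(grad_w1)              # None: the only path to w1 goes through a_stoped
print(a_val, a_stoped_val)  # 6.0 6.0: stop_gradient does not change the value
print(b_val, g2)            # 12.0 6.0: db/dw2 = a_stoped
```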

```python
import tensorflow as tf

a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(gradients))
# Output: [1.0, 1.0]
```

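Although node c is stopped, a and b still receive gradients back through d, so gradient values are still returned. Each is 1.0 because only the path through d contributes; without the stop_gradient, e = 2a + 2b and each gradient would be 2.0.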
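The following sketch (same `tensorflow.compat.v1` assumption) builds the c/d example twice, once with `tf.stop_gradient(c)` and once without, so the two results can be compared side by side. The names `e_stopped`, `e_plain`, `grads_stopped` and `grads_plain` are illustrative.

```python
# Sketch: compare gradients with and without stop_gradient on one branch.
# Uses the TF1 graph API via tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
d = tf.add(a, b)

e_stopped = tf.add(tf.stop_gradient(c), d)  # gradients flow back only through d
e_plain = tf.add(c, d)                      # gradients flow back through c and d

grads_stopped = tf.gradients(e_stopped, xs=[a, b])
grads_plain = tf.gradients(e_plain, xs=[a, b])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    stopped_vals, plain_vals = sess.run([grads_stopped, grads_plain])

print([float(v) for v in stopped_vals])  # [1.0, 1.0]
print([float(v) for v in plain_vals])    # [2.0, 2.0]
```

The difference of exactly 1.0 per variable is the contribution of the c branch that `tf.stop_gradient` removes.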

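```python
import tensorflow as tf

w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
# b = w1 * 3.0 * w2
b = tf.multiply(a_stoped, w2)
opt = tf.train.GradientDescentOptimizer(0.1)
gradients = tf.gradients(b, xs=tf.trainable_variables())
# The next line raises an error because gradients[0] is None
# (None has no .name attribute), so it is commented out here:
# tf.summary.histogram(gradients[0].name, gradients[0])
train_op = opt.apply_gradients(zip(gradients, tf.trainable_variables()))
print(gradients)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(train_op))
    print(sess.run([w1, w2]))
```

Apart from the histogram line, everything runs normally, both the gradient computation and the variable update: apply_gradients skips (gradient, variable) pairs whose gradient is None, so w1 is simply left unchanged. This design feels somewhat awkward to me; it would be nicer if the gradient flowing back were just 0.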
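Regarding the wish that an unconnected gradient should simply be 0: newer TensorFlow releases add an `unconnected_gradients` argument to `tf.gradients`, and passing `tf.UnconnectedGradients.ZERO` returns a zero tensor instead of `None`. The sketch below assumes such a release and `tensorflow.compat.v1`; the argument may not exist in the older 1.x versions this note was written with. With zeros in place, a histogram summary can be created for every gradient, and the zero gradient leaves `w1` unchanged during the update. The summary names (`<variable>_grad`) and the names `var_list`, `summaries`, `grad_vals` are illustrative.

```python
# Sketch: ask tf.gradients for zeros instead of None for unconnected variables
# (assumes a TF release with the `unconnected_gradients` argument).
# Uses the TF1 graph API via tensorflow.compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
# b = w1 * 3.0 * w2, but no gradient flows back through a
b = tf.multiply(tf.stop_gradient(a), w2)

var_list = [w1, w2]
gradients = tf.gradients(
    b, xs=var_list, unconnected_gradients=tf.UnconnectedGradients.ZERO)

# Every entry is a Tensor now, so a histogram summary can be built for each.
# The "<variable>_grad" names are illustrative.
summaries = [tf.summary.histogram(v.op.name + "_grad", g)
             for g, v in zip(gradients, var_list)]

opt = tf.train.GradientDescentOptimizer(0.1)
train_op = opt.apply_gradients(zip(gradients, var_list))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    grad_vals = sess.run(gradients)      # db/dw1 = 0.0, db/dw2 = a = 6.0
    sess.run(train_op)
    w1_val, w2_val = sess.run([w1, w2])  # w1 stays 2.0, w2 = 2.0 - 0.1 * 6.0

print([float(v) for v in grad_vals])  # [0.0, 6.0]
print(float(w1_val), float(w2_val))
```

Filtering out the `None` pairs, as apply_gradients already does internally, remains an alternative when the argument is not available.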