tensorflow学习笔记(三十):tf.gradients 与 tf.stop_gradient()
来源:互联网 发布:闲鱼 淘宝二手ipad 编辑:程序博客网 时间:2024/05/01 20:29
gradient
tensorflow
中有一个计算梯度的函数tf.gradients(ys, xs)
,要注意的是,xs
中的x
必须要与ys
相关,不相关的话,会报错。
代码中定义了两个变量w1
, w2
, 但res
只与w1
相关
#wrongimport tensorflow as tfw1 = tf.Variable([[1,2]])w2 = tf.Variable([[3,4]])res = tf.matmul(w1, [[2],[1]])grads = tf.gradients(res,[w1,w2])with tf.Session() as sess: tf.global_variables_initializer().run() re = sess.run(grads) print(re)
错误信息
TypeError: Fetch argument None has invalid type
# rightimport tensorflow as tfw1 = tf.Variable([[1,2]])w2 = tf.Variable([[3,4]])res = tf.matmul(w1, [[2],[1]])grads = tf.gradients(res,[w1])with tf.Session() as sess: tf.global_variables_initializer().run() re = sess.run(grads) print(re)# [array([[2, 1]], dtype=int32)]
tf.stop_gradient()
阻挡节点BP
的梯度
import tensorflow as tfw1 = tf.Variable(2.0)w2 = tf.Variable(2.0)a = tf.multiply(w1, 3.0)a_stoped = tf.stop_gradient(a)# b=w1*3.0*w2b = tf.multiply(a_stoped, w2)gradients = tf.gradients(b, xs=[w1, w2])print(gradients)#输出#[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
可见,一个节点
被 stop
之后,这个节点上的梯度,就无法再向前BP
了。由于w1
变量的梯度只能来自a
节点,所以,计算梯度返回的是None
。
a = tf.Variable(1.0)b = tf.Variable(1.0)c = tf.add(a, b)c_stoped = tf.stop_gradient(c)d = tf.add(a, b)e = tf.add(c_stoped, d)gradients = tf.gradients(e, xs=[a, b])with tf.Session() as sess: tf.global_variables_initializer().run() print(sess.run(gradients))#输出 [1.0, 1.0]
虽然 c
节点被stop
了,但是a,b
还有从d
传回的梯度,所以还是可以输出梯度值的。
import tensorflow as tfw1 = tf.Variable(2.0)w2 = tf.Variable(2.0)a = tf.multiply(w1, 3.0)a_stoped = tf.stop_gradient(a)# b=w1*3.0*w2b = tf.multiply(a_stoped, w2)opt = tf.train.GradientDescentOptimizer(0.1)gradients = tf.gradients(b, xs=tf.trainable_variables())tf.summary.histogram(gradients[0].name, gradients[0])# 这里会报错,因为gradients[0]是None#其它地方都会运行正常,无论是梯度的计算还是变量的更新。总觉着tensorflow这么设计有点不好,#不如改成流过去的梯度为0train_op = opt.apply_gradients(zip(gradients, tf.trainable_variables()))print(gradients)with tf.Session() as sess: tf.global_variables_initializer().run() print(sess.run(train_op)) print(sess.run([w1, w2]))
0 0
- tensorflow学习笔记(三十):tf.gradients 与 tf.stop_gradient()
- tensorflow 中的tf.gradients 与 tf.stop_gradient() 函数
- opt.compute_gradients() 与 tf.gradients 与 tf.stop_gradient()
- TensorFlow梯度求解tf.gradients
- tensorflow学习笔记(2):常量(tf.constant)与变量(tf.Varialbe)
- tensorflow学习笔记(九):tf.shape()与tensor.get_shape()
- TensorFlow学习笔记(十六)tf.random_normal
- #tensorflow学习笔记#tf.gather
- tensorflow学习笔记--tf.one_hot
- TensorFlow学习笔记之tf.nn.softmax()与tf.nn.softmax_cross_entropy_with_logits的用法
- TensorFlow 学习(一)—— tf.get_variable() vs tf.Variable(),tf.name_scope() vs tf.variable_scope()
- tensorflow学习——tf.floor与tf.train.batch
- TensorFlow学习--tf.add_to_collection与tf.get_collection使用
- TensorFlow学习笔记(5)----TF生成数据的方法
- tensorflow学习笔记(二十六):构建TF代码
- TensorFlow学习笔记(5)----TF生成数据的方法
- TensorFlow学习笔记(五):tf.reshape用法
- tensorflow学习笔记(六):TF.contrib.learn大杂烩
- 个人账目管理系统(一)数据库连接
- hadoop提交作业报错:InvalidAuxServiceException: The auxService:mapreduce_shuffle does not exist
- Spring Cloud构建微服务架构(一)服务注册与发现
- MySQL创建数据表时设定引擎MyISAM/InnoDB
- 8天学通MongoDB——第一天 基础入门
- tensorflow学习笔记(三十):tf.gradients 与 tf.stop_gradient()
- spring boot使用任务调度
- toLocaleString、toString、unshift、values
- 短信平台接口事例
- com.control
- 257. Binary Tree Paths
- ubuntu挂载新硬盘
- 微信公众号登录 Laravel版
- PHP删除Array数组里指定的key