tensorflow学习笔记(十四):tensorlfow操作gradient
来源:互联网 发布:浙江贰贰网络 看准网 编辑:程序博客网 时间:2024/05/18 02:48
tensorflow中操作gradient-clip
在训练深度神经网络的时候,我们经常会碰到梯度消失
和梯度爆炸
问题,scientists提出了很多方法来解决这些问题,本篇就介绍一下如何在tensorflow中使用clip来address这些问题
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
在调用minimize
方法的时候,底层实际干了两件事:
- 计算所有 trainable variables
梯度
- apply them to variables
随后, 在我们 sess.run(train_op)
的时候, 会对 variables
进行更新
clip
那我们如果想处理一下计算完的 gradients
,那该怎么办呢?
官方给出了以下步骤
1. Compute the gradients with compute_gradients(). 计算梯度
2. Process the gradients as you wish. 处理梯度
3. Apply the processed gradients with apply_gradients(). apply处理后的梯度给variables
这样,我们以后在train
的时候就会使用 processed gradient去更新 variable
框架:
# Create an optimizer.optimizer必须和variable在一个设备上声明opt = GradientDescentOptimizer(learning_rate=0.1)# Compute the gradients for a list of variables.grads_and_vars = opt.compute_gradients(loss, <list of variables>)# grads_and_vars is a list of tuples (gradient, variable). Do whatever you# need to the 'gradient' part, for example cap them, etc.capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]# Ask the optimizer to apply the capped gradients.opt.apply_gradients(capped_grads_and_vars)
例子:
#return a list of trainable variable in you modelparams = tf.trainable_variables()#create an optimizeropt = tf.train.GradientDescentOptimizer(self.learning_rate)#compute gradients for paramsgradients = tf.gradients(loss, params)#process gradientsclipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)train_op = opt.apply_gradients(zip(clipped_gradients, params)))
这时, sess.run(train_op)
就可以进行训练了
0 0
- tensorflow学习笔记(十四):tensorlfow操作gradient
- tensorflow学习笔记(二十四):Bucketing
- tensorflow学习笔记十四:tensorflow中的tf.app.run()
- 图像处理(二十四)Gradient Domain High Dynamic Range Compression学习笔记
- JavaScript学习笔记二十四:操作DOM
- TensorFlow学习笔记(十四)TensorFLow 用mnist数据做classification
- TensorFlow学习笔记(七)feeds操作
- TensorFlow学习笔记(八)add_layer操作
- TensorFlow学习笔记(1)--TensorFlow简介,常用基本操作
- TensorFlow学习笔记--TensorFlow简介,常用基本操作
- tensorflow学习笔记十四:TF官方教程学习 tf.contrib.learn Quickstart
- 深度学习笔记——深度学习框架TensorFlow之MLP(十四)
- Androin学习笔记二十四:wifi连接操作
- TensorFlow学习笔记(二十四)自制TFRecord数据集 读取、显示及代码详解
- TensorFlow 学习(十四)—— contrib
- 学习笔记(十四)
- c++学习笔记十四
- Django 学习笔记(十四)
- 关于TextView的setText()方法报resource not found exception的问题
- 我好久没有写博客了
- android 开发:InputMethodManager内存泄露解决
- 堆和栈的区别
- Makefile详解-程序的编译和链接
- tensorflow学习笔记(十四):tensorlfow操作gradient
- Python基础学习--字符串格式化
- Mybatis学习笔记一:环境搭建以及简单使用
- 开发一个好项目:三、创建数据源,首先创建本地数据源
- 【Java】不使用第三方变量交换两个变量的值
- 使用hibernate传入数据到数据库出现乱码问题解决
- android开发(如何开发一个可以维护的好项目):四 、项目结构
- Makefile详解-介绍
- entity framwork 链接字符串不保存在config文件方法