TensorFlow学习笔记(1):LSTM相关代码
来源:互联网 发布:设计公司logo软件 编辑:程序博客网 时间:2024/06/08 02:46
LSTM是seq2seq模型中经典的子结构,TensorFlow中提供了相应的结构,供我们使用:
tensorflow提供了LSTM实现的一个basic版本,不包含lstm的一些高级扩展,同时也提供了一个标准接口,其中包含了lstm的扩展。分别为:tf.nn.rnn_cell.BasicLSTMCell(), tf.nn.rnn_cell.LSTMCell()
tensorflow中的BasicLSTMCell()是完全按照这个结构进行设计的
#tf.nn.rnn_cell.BasicLSTMCell(num_units, forget_bias, input_size, state_is_tupe=Flase, activation=tanh)cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, forget_bias=1.0, input_size=None, state_is_tupe=Flase, activation=tanh)#num_units:图一中ht的维数,如果num_units=10,那么ht就是10维行向量#forget_bias:还不清楚这个是干嘛的#input_size:[batch_size, max_time, size]。假设要输入一句话,这句话的长度是不固定的,max_time就代表最长的那句话是多长,size表示你打算用多长的向量代表一个word,即embedding_size(embedding_size和size的值不一定要一样)#state_is_tuple:true的话,返回的状态是一个tuple:(c=array([[]]), h=array([[]]):其中c代表Ct的最后时间的输出,h代表Ht最后时间的输出,h是等于最后一个时间的output的#图三向上指的ht称为output#此函数返回一个lstm_cell,即图一中的一个A
如果你想要设计一个多层的LSTM网络,你就会用到tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False),这里多层的意思上向上堆叠,而不是按时间展开
lstm_cell = tf.nn.rnn_cell.MultiRNNCells(cells, state_is_tuple=False)#cells:一个cell列表,将列表中的cell一个个堆叠起来,如果使用cells=[cell]*4的话,就是四曾,每层cell输入输出结构相同#如果state_is_tuple:则返回的是 n-tuple,其中n=len(cells): tuple:(c=[batch_size, num_units], h=[batch_size,num_units])
这是,网络已经搭好了,tensorflow提供了一个非常方便的方法来生成初始化网络的state
initial_state = lstm_cell.zero_state(batch_size, dtype=)#返回[batch_size, 2*len(cells)],或者[batch_size, s]#这个函数只是用来生成初始化值的
现在进行时间展开,有两种方法:
法一:
使用现成的接口
tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,dtype=None,time_major=False)#此函数会通过,inputs中的max_time将网络按时间展开#cell:将上面的lstm_cell传入就可以#inputs:[batch_size, max_time, size]如果time_major=Flase. [max_time, batch_size, size]如果time_major=True#sequence_length:是一个list,如果你要输入三句话,且三句话的长度分别是5,10,25,那么sequence_length=[5,10,25]#返回:(outputs, states):output,[batch_size, max_time, num_units]如果time_major=False。 [max_time,batch_size,num_units]如果time_major=True。states:[batch_size, 2*len(cells)]或[batch_size,s]#outputs输出的是最上面一层的输出,states保存的是最后一个时间输出的states
第二种方法:outputs = []states = initial_stateswith tf.variable_scope("RNN"): for time_step in range(max_time): if time_step>0:tf.get_variable_scope().reuse_variables()#LSTM同一曾参数共享, (cell_out, state) = lstm_cell(inputs[:,time_step,:], state) outputs.append(cell_out)
tenforflow提供了tf.nn.rnn_cell.GRUCell()构建一个GRU单元
cell = tenforflow提供了tf.nn.rnn_cell.GRUCell(num_units, input_size=None, activation=tanh)#参考lstm cell 使用
阅读全文
0 0
- TensorFlow学习笔记(1):LSTM相关代码
- Tensorflow学习笔记--MNIST LSTM分类器代码
- TensorFlow 笔记(三):多层 LSTM代码详细介绍
- tensorflow笔记:多层LSTM代码分析
- tensorflow笔记:多层LSTM代码分析
- tensorflow笔记:多层LSTM代码分析
- tensorflow学习笔记(六):LSTM 与 GRU
- tensorflow学习笔记(六):LSTM 与 GRU
- TensorFlow学习笔记9:LSTM搭建
- tensorflow学习笔记:LSTM 与 GRU
- 深度学习笔记(1):caffe 添加新层 attention LSTM layer和LSTM layer代码精读
- LSTM学习总结(Based tensorflow)
- tensorflow中lstm学习
- RNN学习笔记(六)-GRU,LSTM 代码实现
- 《白话深度学习与Tensorflow》学习笔记(3)HMM RNN LSTM
- tensorflow学习笔记(三十七):如何自定义LSTM的initial state
- TensorFlow学习笔记(四):手写数字识别之LSTM网络
- 长短期记忆(LSTM)-tensorflow代码实现
- TCP/IP协议简述(转载)
- 如何使用React构建同构(isomorphic)应用
- 关于xls及类似表格中String类型数值数据转化为int及float等等格式
- ASM磁盘组及磁盘 添加、删除
- freemarker常用知识
- TensorFlow学习笔记(1):LSTM相关代码
- js点击轮播或者自动轮播图代码
- 微信菜单获取二维码图片的优化指南——该公众号暂时无法提供服务
- vue2.x自定义组件上使用v-model指令
- Spring Cloud (18) | 给Eureka Server加上安全验证
- 原始的stl文档
- 设计模式-java实现动态代理
- 基本包装类型
- java Zip文件的压缩与解压, 兼容Windows和Linux