junhyukoh的lstm代码解析

来源:互联网 发布:多核优化的游戏 编辑:程序博客网 时间:2024/06/05 09:25

代码地址 https://github.com/junhyukoh/caffe-lstm
此代码是junhyukoh用于生成序列的一个代码,其中有一个例子用于生成一组数。本文主要阐述该例子,并谈谈对lstm的简单理解。本人刚接触DNN两个月,只会caffe一点皮毛,torch,theano等不会使用,caffe下的RNN代码不多,本文是其中一个。据说Jeff Donahue’s 的lstm代码即将并入caffe。

一、lstm生成序列例子
本例中随机生成320个数字序列,作为训练样本。构造LSTM网络(分别构造了15个隐藏节点1层LSTM,50个隐藏节点1层LSTM,7个隐藏节点3层LSTM,23个隐藏节点3层LSTM。http://ethereon.github.io/netscope/#/editor 可以根据prototxt生成网络图,挺漂亮),预测不同长度的序列。经过训练之后,无测试数据,让网络自身输出序列。最终拟合的结果还是不错的。
1层STML
测试时,没有输入,如何输出呢?思想是定位到最开始,一个一个输出。有如下代码:

  for (int i = 0; i < TotalLength; ++i) {     test_clip_blob->mutable_cpu_data()[0] = i > 0; //这句话没看懂啊!    const vector<Blob<float>* >& pred = test_net->ForwardPrefilled();    CHECK_EQ(pred.size(), 1);    CHECK_EQ(pred[0]->count(), 1);    //sequence.cpu_data()[i]是真实数据,*pred[0]->cpu_data()是网络输出数据    log_file << sequence.cpu_data()[i] << " " << *pred[0]->cpu_data() << std::endl;  }

二、lstm简单理解
Andrej Karpathy的直觉RNN http://karpathy.github.io/2015/05/21/rnn-effectiveness/
有一篇好的翻译博客:http://blog.csdn.net/leo_is_ant/article/details/50411020

个人理解,通过网络的展开(序列有几个数字,网络展开就有多少个)学习,网络学习到了输入的数据。也可以理解为一种拟合,用于预测下一个数据或后面几个数据。LSTM是RNN的一个小trick,解决了长时记忆问题,直觉上说,LSTM自带长时记忆。 下一篇将详细介绍LSTM。

0 0
原创粉丝点击