用RNN拟合加法运算
来源:互联网 发布:mysql后加自动增长id 编辑:程序博客网 时间:2024/06/05 08:01
最近在看keras文档的时候看到一个关于RNN的很有意思的应用——用RNN拟合加法运算。看完之后我自己也实现了一下,原版代码在这里https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py
一. 实验描述
用RNN拟合整数的加法运算,其中被加数和加数在区间
二. 实验思路
先从数据讲起,因为被加数和加数在区间{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' '}
,所以该实验中一个加法表达式可以表示成一个
[[False False True False False False False False False False False False] [False False True False False False False False False False False False] [False False False True False False False False False False False False] [False False False False False False False False False False True False] [False True False False False False False False False False False False] [False True False False False False False False False False False False] [False False False False False False False False False False False True]]
同理,加法运算的结果可以表示为
[[False False True False False False False False False False False False] [False False False True False False False False False False False False] [False False False False True False False False False False False False] [False False False False False False False False False False False True]]
把加法表达式和结果表示出来后,接下来就可以把数据放到RNN中训练了,我参考了https://github.com/fchollet/keras/blob/master/examples/addition_rnn.py中的网络结构,其结构图如下:
三. 实验数据
本实验随机生成了
四. 实验代码
#encoding: utf-8import numpyfrom keras.models import Sequentialfrom keras.engine.training import slice_Xfrom keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrentfrom keras.layers.recurrent import LSTMDATA_SIZE = 50000NUMBER_LEN = 3HIDDEN_SIZE = 128BATCH_SIZE = 128EXPRESSION_LEN = NUMBER_LEN + 1 + NUMBER_LENsymbol_table = '0123456789+ 'index_2_symbol = {}symbol_2_index = {}for i, ch in enumerate(symbol_table): index_2_symbol[i] = ch symbol_2_index[ch] = idef encode(s): feat_mat = numpy.zeros((len(s), len(symbol_table)), dtype = numpy.bool); for i, ch in enumerate(s): feat_mat[i, symbol_2_index[ch]] = 1 return feat_matdef load_data(split): print 'Generating data...' used = set() exp_list = [] ans_list = [] while (len(exp_list) < DATA_SIZE): f = lambda: int(''.join(numpy.random.choice(list('0123456789')) for i in range(numpy.random.randint(1, NUMBER_LEN + 1)))) a, b = f(), f() key = (a, b) if key in used: continue used.add(key) exp = '{}+{}'.format(a, b); exp += ' ' * (EXPRESSION_LEN - len(exp)) ans = str(a + b) ans += ' ' * (NUMBER_LEN + 1 - len(ans)) exp_list.append(exp) ans_list.append(ans) data = numpy.zeros((len(exp_list), EXPRESSION_LEN, len(symbol_table)), dtype = numpy.bool) label = numpy.zeros((len(ans_list), NUMBER_LEN + 1, len(symbol_table)), dtype = numpy.bool) for i, exp in enumerate(exp_list): data[i] = encode(exp) for i, ans in enumerate(ans_list): label[i] = encode(ans) dividing_line = (int)(split * DATA_SIZE) indices = numpy.random.permutation(DATA_SIZE) train_idx, test_idx = indices[:dividing_line], indices[dividing_line:] train_data, test_data = data[train_idx, :], data[test_idx, :] train_label, test_label = label[train_idx,], label[test_idx, ] return (train_data, train_label), (test_data, test_label)def deal(): (train_data, train_label), (test_data, test_label) = load_data(0.7) # print train_data.shape # print train_label.shape # print test_data.shape # print test_label.shape print('Build model...') model = Sequential() model.add(LSTM(HIDDEN_SIZE, input_shape = (EXPRESSION_LEN, len(symbol_table)))) model.add(RepeatVector(NUMBER_LEN + 1)) model.add(LSTM(HIDDEN_SIZE, return_sequences = True)) model.add(TimeDistributed(Dense(len(symbol_table)))) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) for i in range(10, 100, 10): print 'iter: ', i model.fit(train_data, train_label, batch_size = BATCH_SIZE, nb_epoch = 10, validation_split = 0.1, verbose = 0) score = model.evaluate(test_data, test_label, batch_size = BATCH_SIZE, verbose = 0) print 'acc: ', score[1]if __name__ == "__main__": deal()
五. 实验结果
iter: 10acc: 0.717766666667iter: 20acc: 0.913150000032iter: 30acc: 0.951416666667iter: 40acc: 0.968233333302iter: 50acc: 0.978iter: 60acc: 0.983066666635iter: 70acc: 0.9844iter: 80acc: 0.985700000032iter: 90acc: 0.981200000032
如有错误请指正
0 0
- 用RNN拟合加法运算
- 用DELPHI实现加法运算
- 用位运算实现加法
- 用位运算实现加法
- 加法运算
- 加法运算
- 【深度学习】python用RNN中LSTM进行正弦函数拟合
- 1.7 用加法实现乘除减运算
- 用发消息方式实现加法运算
- 用C++实现高精度加法运算
- 用位操作实现加法运算
- Java 用位运算改写加法、乘法
- 用位运算实现加法和减法
- 用位运算实现两个整数的加法运算
- onkey 加法运算
- 虚数的加法运算
- 第二课 加法运算
- 位运算实现加法
- iOS开发常用的宏
- 自定义xml解析框架
- activity与fragment中使用OnActivityResult方法
- Netty 初步介绍
- Android系统应用开发(八)ANR应用程序与无响应对话框自定义
- 用RNN拟合加法运算
- 从数据库获取数据插入页面
- mongoDB常用命令
- 数据管理工具中的1130错误的解决方法
- 增量式整理计算机视觉的相关代码
- Path sum: four ways
- golang简单获取上传文件大小的实现代码
- 204.Singleton-单例(容易题)
- sizeof(结构体)