Keras Example Code 01: addition_rnn.py — Result Demo and Code Walkthrough


1-Code Overview:

RNNs have already achieved notable success on problems such as speech recognition, language modeling, translation, and image captioning.
addition_rnn.py demonstrates how an LSTM-based RNN model can learn to perform addition as a sequence-to-sequence task.
Main steps:
# Step 1: generate the training examples
# Step 2: build the RNN model
# Step 3: train and predict
A single training iteration takes about 6 seconds; reaching 99% accuracy takes about 50 iterations, roughly 5 minutes in total.
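For intuition, here is a minimal sketch (mirroring the data-generation step of the script in section 3 below) of how one question/answer pair is padded and optionally reversed before being one-hot encoded:

DIGITS = 3
MAXLEN = DIGITS + 1 + DIGITS               # longest possible question, e.g. '345+678'

a, b = 35, 615
q = '{}+{}'.format(a, b)                   # '35+615'
query = q + ' ' * (MAXLEN - len(q))        # pad the question to MAXLEN: '35+615 '
ans = str(a + b)
ans += ' ' * (DIGITS + 1 - len(ans))       # pad the answer to DIGITS + 1: '650 '
query = query[::-1]                        # invert the question: ' 516+53'
print(repr(query), repr(ans))              # ' 516+53' '650 '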

1.1-RNN and LSTM Background

For details, see: http://colah.github.io/posts/2015-08-Understanding-LSTMs/
A Chinese translation is available at: http://www.jianshu.com/p/9dc9f41f0b29

An RNN (Recurrent Neural Network) is a network that contains loops, allowing information to persist. An RNN can be viewed as many copies of the same network, each passing a message to the next copy.
One of the appeals of RNNs is that they can connect previous information to the present task, but as the gap between the relevant earlier information and the point where it is needed grows, standard RNNs lose the ability to learn to connect information that far apart.
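A minimal NumPy sketch (not from the original post; weight names and sizes are illustrative) of the recurrence a vanilla RNN cell computes, to make "passing a message to the next copy" concrete:

import numpy as np

def rnn_forward(xs, W_xh, W_hh, b_h):
    """Run a vanilla RNN over a sequence xs of shape (timesteps, input_dim)."""
    h = np.zeros(W_hh.shape[0])            # initial hidden state
    states = []
    for x_t in xs:                         # the same weights are reused at every step
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)   # h_t depends on x_t and h_{t-1}
        states.append(h)
    return np.stack(states)                # hidden state at every time step

# Example: 7 one-hot characters over a 12-symbol vocabulary, 16-dim hidden state.
rng = np.random.default_rng(0)
xs = np.eye(12)[rng.integers(0, 12, size=7)]
W_xh = rng.normal(scale=0.1, size=(16, 12))
W_hh = rng.normal(scale=0.1, size=(16, 16))
b_h = np.zeros(16)
print(rnn_forward(xs, W_xh, W_hh, b_h).shape)   # (7, 16)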

Fortunately, LSTMs do not have this problem.
RNNs have achieved success in speech recognition, language modeling, translation, image captioning, and similar problems, and essential to these successes is the use of LSTMs, a special kind of RNN that works much better than the standard version on many tasks.

1.2-LSTM Networks

Long Short-Term Memory networks, usually just called LSTMs, are a special kind of RNN designed to avoid the long-term dependency problem.

(Figure: the repeating module in an LSTM contains four interacting layers)

LSTMs can remove information from, or add information to, the cell state, carefully regulated by structures called "gates". A gate is a way to optionally let information through: it consists of a sigmoid neural network layer and a pointwise multiplication operation.

An LSTM has three gates that protect and control the cell state. The first step is to decide what information to throw away from the cell state; this decision is made by a sigmoid layer called the forget gate.
The next step is to decide what new information to store in the cell state.
This has two parts. First, a sigmoid layer called the input gate decides which values to update. Then, a tanh layer creates a vector of new candidate values, \tilde{C}_t, that could be added to the state.

(Figure: deciding which values to update)

(Figure: updating the cell state)

Finally, we need to decide what to output. The output is a filtered version of the cell state: a sigmoid layer (the output gate) decides which parts of the cell state to emit, and the cell state is passed through tanh and multiplied by that gate's activation.

(Figure: computing the output)
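To make the gate arithmetic concrete, here is a minimal NumPy sketch of a single LSTM time step (illustrative only and not from the original post; the weight names, shapes, and dict layout are hypothetical, and Keras stores these parameters in fused matrices internally):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W, U, b):
    """One LSTM time step. W, U, b are dicts of weights keyed by gate: 'f', 'i', 'c', 'o'."""
    f_t = sigmoid(W['f'] @ x_t + U['f'] @ h_prev + b['f'])      # forget gate: what to drop from C
    i_t = sigmoid(W['i'] @ x_t + U['i'] @ h_prev + b['i'])      # input gate: which values to update
    C_tilde = np.tanh(W['c'] @ x_t + U['c'] @ h_prev + b['c'])  # candidate values \tilde{C}_t
    C_t = f_t * C_prev + i_t * C_tilde                          # new cell state
    o_t = sigmoid(W['o'] @ x_t + U['o'] @ h_prev + b['o'])      # output gate: what to expose
    h_t = o_t * np.tanh(C_t)                                    # hidden state / output
    return h_t, C_t

# Example with illustrative sizes: 12-dimensional one-hot input, 8-dimensional state.
rng = np.random.default_rng(0)
W = {g: rng.normal(scale=0.1, size=(8, 12)) for g in 'fico'}
U = {g: rng.normal(scale=0.1, size=(8, 8)) for g in 'fico'}
b = {g: np.zeros(8) for g in 'fico'}
h, C = np.zeros(8), np.zeros(8)
for x_t in np.eye(12)[[3, 1, 5]]:        # a toy sequence of three one-hot characters
    h, C = lstm_step(x_t, h, C, W, U, b)
print(h.shape, C.shape)                  # (8,) (8,)

Here f_t decides what to forget, i_t * \tilde{C}_t decides what to add, and o_t decides how much of the updated cell state to expose as h_t.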

Not all LSTMs look the same. In fact, almost every paper involving LSTMs uses a slightly different variant.

1.3-Beyond LSTMs

LSTMs were a major step forward for RNNs. The next big step is attention: for example, when an RNN is used to generate a caption describing an image, attention lets it pick a part of the image to look at for each word it outputs.

2-Run Results:

The output below shows training iteration 16, followed by 10 randomly chosen validation samples (Q = question, T = target answer, ☑ = correct prediction, ☒ = wrong prediction).


--------------------------------------------------
Iteration 16
Train on 45000 samples, validate on 5000 samples
Epoch 1/1
  128/45000 [..............................] - ETA: 6s - loss: 0.1089 - acc: 0.9824
  512/45000 [..............................] - ETA: 6s - loss: 0.1166 - acc: 0.9844
  896/45000 [..............................] - ETA: 5s - loss: 0.1179 - acc: 0.9841
 1280/45000 [..............................] - ETA: 6s - loss: 0.1169 - acc: 0.9838
 1664/45000 [>.............................] - ETA: 6s - loss: 0.1181 - acc: 0.9827
 2048/45000 [>.............................] - ETA: 6s - loss: 0.1161 - acc: 0.9841
 2432/45000 [>.............................] - ETA: 5s - loss: 0.1172 - acc: 0.9840
 2816/45000 [>.............................] - ETA: 5s - loss: 0.1175 - acc: 0.9844
 3200/45000 [=>............................] - ETA: 5s - loss: 0.1195 - acc: 0.9836
 3584/45000 [=>............................] - ETA: 5s - loss: 0.1190 - acc: 0.9840
 3968/45000 [=>............................] - ETA: 5s - loss: 0.1182 - acc: 0.9841
 4352/45000 [=>............................] - ETA: 5s - loss: 0.1184 - acc: 0.9844
 4736/45000 [==>...........................] - ETA: 5s - loss: 0.1182 - acc: 0.9844
 5120/45000 [==>...........................] - ETA: 5s - loss: 0.1177 - acc: 0.9843
 5504/45000 [==>...........................] - ETA: 5s - loss: 0.1178 - acc: 0.9841
 5888/45000 [==>...........................] - ETA: 5s - loss: 0.1181 - acc: 0.9838
 6272/45000 [===>..........................] - ETA: 5s - loss: 0.1187 - acc: 0.9834
 6656/45000 [===>..........................] - ETA: 5s - loss: 0.1185 - acc: 0.9834
 7040/45000 [===>..........................] - ETA: 5s - loss: 0.1188 - acc: 0.9831
 7424/45000 [===>..........................] - ETA: 5s - loss: 0.1189 - acc: 0.9830
 7808/45000 [====>.........................] - ETA: 5s - loss: 0.1190 - acc: 0.9831
 8192/45000 [====>.........................] - ETA: 5s - loss: 0.1189 - acc: 0.9831
 8576/45000 [====>.........................] - ETA: 5s - loss: 0.1195 - acc: 0.9826
 8960/45000 [====>.........................] - ETA: 5s - loss: 0.1197 - acc: 0.9824
 9344/45000 [=====>........................] - ETA: 5s - loss: 0.1199 - acc: 0.9821
 9728/45000 [=====>........................] - ETA: 4s - loss: 0.1201 - acc: 0.9819
10112/45000 [=====>........................] - ETA: 4s - loss: 0.1200 - acc: 0.9820
10496/45000 [=====>........................] - ETA: 4s - loss: 0.1200 - acc: 0.9818
10880/45000 [======>.......................] - ETA: 4s - loss: 0.1200 - acc: 0.9819
11264/45000 [======>.......................] - ETA: 4s - loss: 0.1203 - acc: 0.9817
11648/45000 [======>.......................] - ETA: 4s - loss: 0.1202 - acc: 0.9815
12032/45000 [=======>......................] - ETA: 4s - loss: 0.1202 - acc: 0.9815
12416/45000 [=======>......................] - ETA: 4s - loss: 0.1205 - acc: 0.9812
12800/45000 [=======>......................] - ETA: 4s - loss: 0.1205 - acc: 0.9811
13184/45000 [=======>......................] - ETA: 4s - loss: 0.1211 - acc: 0.9810
13568/45000 [========>.....................] - ETA: 4s - loss: 0.1214 - acc: 0.9809
13952/45000 [========>.....................] - ETA: 4s - loss: 0.1214 - acc: 0.9808
14336/45000 [========>.....................] - ETA: 4s - loss: 0.1214 - acc: 0.9808
14720/45000 [========>.....................] - ETA: 4s - loss: 0.1213 - acc: 0.9808
15104/45000 [=========>....................] - ETA: 4s - loss: 0.1212 - acc: 0.9808
15488/45000 [=========>....................] - ETA: 4s - loss: 0.1211 - acc: 0.9809
15872/45000 [=========>....................] - ETA: 4s - loss: 0.1208 - acc: 0.9811
16256/45000 [=========>....................] - ETA: 4s - loss: 0.1206 - acc: 0.9811
16640/45000 [==========>...................] - ETA: 3s - loss: 0.1205 - acc: 0.9811
17024/45000 [==========>...................] - ETA: 3s - loss: 0.1204 - acc: 0.9812
17408/45000 [==========>...................] - ETA: 3s - loss: 0.1202 - acc: 0.9814
17792/45000 [==========>...................] - ETA: 3s - loss: 0.1200 - acc: 0.9813
18176/45000 [===========>..................] - ETA: 3s - loss: 0.1199 - acc: 0.9812
18560/45000 [===========>..................] - ETA: 3s - loss: 0.1197 - acc: 0.9812
18944/45000 [===========>..................] - ETA: 3s - loss: 0.1195 - acc: 0.9813
19328/45000 [===========>..................] - ETA: 3s - loss: 0.1190 - acc: 0.9815
19712/45000 [============>.................] - ETA: 3s - loss: 0.1190 - acc: 0.9815
20096/45000 [============>.................] - ETA: 3s - loss: 0.1189 - acc: 0.9815
20480/45000 [============>.................] - ETA: 3s - loss: 0.1188 - acc: 0.9816
20864/45000 [============>.................] - ETA: 3s - loss: 0.1187 - acc: 0.9816
21248/45000 [=============>................] - ETA: 3s - loss: 0.1186 - acc: 0.9816
21632/45000 [=============>................] - ETA: 3s - loss: 0.1187 - acc: 0.9815
22016/45000 [=============>................] - ETA: 3s - loss: 0.1187 - acc: 0.9815
22400/45000 [=============>................] - ETA: 3s - loss: 0.1188 - acc: 0.9815
22784/45000 [==============>...............] - ETA: 3s - loss: 0.1188 - acc: 0.9815
23168/45000 [==============>...............] - ETA: 3s - loss: 0.1187 - acc: 0.9815
23552/45000 [==============>...............] - ETA: 3s - loss: 0.1188 - acc: 0.9814
23936/45000 [==============>...............] - ETA: 2s - loss: 0.1189 - acc: 0.9813
24320/45000 [===============>..............] - ETA: 2s - loss: 0.1190 - acc: 0.9813
24704/45000 [===============>..............] - ETA: 2s - loss: 0.1190 - acc: 0.9813
25088/45000 [===============>..............] - ETA: 2s - loss: 0.1188 - acc: 0.9813
25472/45000 [===============>..............] - ETA: 2s - loss: 0.1187 - acc: 0.9814
25856/45000 [================>.............] - ETA: 2s - loss: 0.1186 - acc: 0.9813
26240/45000 [================>.............] - ETA: 2s - loss: 0.1185 - acc: 0.9813
26624/45000 [================>.............] - ETA: 2s - loss: 0.1185 - acc: 0.9813
27008/45000 [=================>............] - ETA: 2s - loss: 0.1184 - acc: 0.9814
27392/45000 [=================>............] - ETA: 2s - loss: 0.1182 - acc: 0.9815
27776/45000 [=================>............] - ETA: 2s - loss: 0.1182 - acc: 0.9815
28160/45000 [=================>............] - ETA: 2s - loss: 0.1182 - acc: 0.9814
28544/45000 [==================>...........] - ETA: 2s - loss: 0.1182 - acc: 0.9814
28928/45000 [==================>...........] - ETA: 2s - loss: 0.1181 - acc: 0.9814
29312/45000 [==================>...........] - ETA: 2s - loss: 0.1181 - acc: 0.9814
29696/45000 [==================>...........] - ETA: 2s - loss: 0.1181 - acc: 0.9813
30080/45000 [===================>..........] - ETA: 2s - loss: 0.1185 - acc: 0.9811
30464/45000 [===================>..........] - ETA: 2s - loss: 0.1187 - acc: 0.9809
30848/45000 [===================>..........] - ETA: 1s - loss: 0.1187 - acc: 0.9809
31232/45000 [===================>..........] - ETA: 1s - loss: 0.1187 - acc: 0.9808
31616/45000 [====================>.........] - ETA: 1s - loss: 0.1186 - acc: 0.9809
32000/45000 [====================>.........] - ETA: 1s - loss: 0.1184 - acc: 0.9809
32384/45000 [====================>.........] - ETA: 1s - loss: 0.1185 - acc: 0.9809
32768/45000 [====================>.........] - ETA: 1s - loss: 0.1184 - acc: 0.9809
33152/45000 [=====================>........] - ETA: 1s - loss: 0.1183 - acc: 0.9809
33536/45000 [=====================>........] - ETA: 1s - loss: 0.1181 - acc: 0.9809
33920/45000 [=====================>........] - ETA: 1s - loss: 0.1181 - acc: 0.9810
34304/45000 [=====================>........] - ETA: 1s - loss: 0.1180 - acc: 0.9809
34688/45000 [======================>.......] - ETA: 1s - loss: 0.1179 - acc: 0.9810
35072/45000 [======================>.......] - ETA: 1s - loss: 0.1178 - acc: 0.9810
35456/45000 [======================>.......] - ETA: 1s - loss: 0.1178 - acc: 0.9811
35840/45000 [======================>.......] - ETA: 1s - loss: 0.1176 - acc: 0.9811
36224/45000 [=======================>......] - ETA: 1s - loss: 0.1175 - acc: 0.9811
36608/45000 [=======================>......] - ETA: 1s - loss: 0.1173 - acc: 0.9812
36992/45000 [=======================>......] - ETA: 1s - loss: 0.1172 - acc: 0.9812
37376/45000 [=======================>......] - ETA: 1s - loss: 0.1172 - acc: 0.9812
37760/45000 [========================>.....] - ETA: 1s - loss: 0.1170 - acc: 0.9812
38144/45000 [========================>.....] - ETA: 0s - loss: 0.1167 - acc: 0.9813
38528/45000 [========================>.....] - ETA: 0s - loss: 0.1165 - acc: 0.9814
38912/45000 [========================>.....] - ETA: 0s - loss: 0.1163 - acc: 0.9815
39296/45000 [=========================>....] - ETA: 0s - loss: 0.1161 - acc: 0.9815
39680/45000 [=========================>....] - ETA: 0s - loss: 0.1161 - acc: 0.9816
40064/45000 [=========================>....] - ETA: 0s - loss: 0.1159 - acc: 0.9816
40448/45000 [=========================>....] - ETA: 0s - loss: 0.1158 - acc: 0.9817
40832/45000 [==========================>...] - ETA: 0s - loss: 0.1157 - acc: 0.9817
41216/45000 [==========================>...] - ETA: 0s - loss: 0.1156 - acc: 0.9817
41600/45000 [==========================>...] - ETA: 0s - loss: 0.1157 - acc: 0.9816
41984/45000 [==========================>...] - ETA: 0s - loss: 0.1155 - acc: 0.9816
42368/45000 [===========================>..] - ETA: 0s - loss: 0.1154 - acc: 0.9817
42752/45000 [===========================>..] - ETA: 0s - loss: 0.1154 - acc: 0.9816
43136/45000 [===========================>..] - ETA: 0s - loss: 0.1155 - acc: 0.9816
43520/45000 [============================>.] - ETA: 0s - loss: 0.1160 - acc: 0.9813
43904/45000 [============================>.] - ETA: 0s - loss: 0.1161 - acc: 0.9812
44288/45000 [============================>.] - ETA: 0s - loss: 0.1163 - acc: 0.9811
44672/45000 [============================>.] - ETA: 0s - loss: 0.1166 - acc: 0.9809
45000/45000 [==============================] - 6s - loss: 0.1167 - acc: 0.9808 - val_loss: 0.1427 - val_acc: 0.9587
Q 637+517
T 1154
☑ 1154
---
Q 441+56 
T 497 
☑ 497 
---
Q 35+615 
T 650 
☑ 650 
---
Q 297+768
T 1065
☒ 1055
---
Q 767+136
T 903 
☑ 903 
---
Q 90+93  
T 183 
☑ 183 
---
Q 329+258
T 587 
☒ 687 
---
Q 966+7  
T 973 
☑ 973 
---
Q 539+81 
T 620 
☑ 620 
---
Q 930+65 
T 995 
☒ 196 
---

3-Code Walkthrough:

# -*- coding: utf-8 -*-
'''An implementation of sequence to sequence learning for performing addition

Input: "535+61"
Output: "596"

Padding is handled by using a repeated sentinel character (space).

Input may optionally be inverted, shown to increase performance in many tasks in:
"Learning to Execute"
http://arxiv.org/abs/1410.4615
and
"Sequence to Sequence Learning with Neural Networks"
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
Theoretically it introduces shorter term dependencies between source and target.

Two digits inverted:
+ One layer LSTM (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
Three digits inverted:
+ One layer LSTM (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
Four digits inverted:
+ One layer LSTM (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
Five digits inverted:
+ One layer LSTM (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
'''
from __future__ import print_function
from keras.models import Sequential
from keras import layers
import numpy as np
from six.moves import range

# If Keras runs on the TensorFlow backend, add these lines first so that
# TensorFlow does not grab all GPU memory up front but grows it as needed.
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.per_process_gpu_memory_fraction = 0.9
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
set_session(sess)  # register the session with Keras so the settings take effect


class CharacterTable(object):
    """Given a set of characters:
    + Encode them to a one hot integer representation
    + Decode the one hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    """
    def __init__(self, chars):
        """Initialize character table.

        # Arguments
            chars: Characters that can appear in the input.
        """
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

    def encode(self, C, num_rows):
        """One hot encode given string C.

        # Arguments
            num_rows: Number of rows in the returned one hot encoding. This is
                used to keep the # of rows for each data the same.
        """
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(C):
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        if calc_argmax:
            x = x.argmax(axis=-1)
        return ''.join(self.indices_char[x] for x in x)


class colors:
    ok = '\033[92m'
    fail = '\033[91m'
    close = '\033[0m'

# Parameters for the model and dataset.
TRAINING_SIZE = 50000
DIGITS = 3
INVERT = True

# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
# int is DIGITS.
MAXLEN = DIGITS + 1 + DIGITS

# All the numbers, plus sign and space for padding.
chars = '0123456789+ '
ctable = CharacterTable(chars)
'''After construction, ctable holds two lookup dictionaries:
char to indices:
{' ': 0, '+': 1, '0': 2, '1': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11}
indices to char:
{0: ' ', 1: '+', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9'}
'''

questions = []
expected = []
seen = set()
print('Generating data...')

# Step 1: generate the training examples.
# The while loop generates addition questions and their answers, stored in
# `questions` and `expected` respectively. This only takes milliseconds.
while len(questions) < TRAINING_SIZE:
    # Build a random number with 1 to DIGITS digits drawn from '0123456789'.
    f = lambda: int(''.join(np.random.choice(list('0123456789'))
                    for i in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()  # a random pair, e.g., (3, 1)
    # Skip any addition questions we've already seen
    # Also skip any such that x+Y == Y+x (hence the sorting).
    key = tuple(sorted((a, b)))  # (1, 3)
    if key in seen:
        continue
    seen.add(key)
    # Pad the data with spaces such that it is always MAXLEN.
    q = '{}+{}'.format(a, b)                 # '3+1'
    query = q + ' ' * (MAXLEN - len(q))      # '3+1    '
    ans = str(a + b)                         # '4'
    # Answers can be of maximum size DIGITS + 1.
    ans += ' ' * (DIGITS + 1 - len(ans))     # '4   '
    if INVERT:
        # Reverse the query, e.g., '12+345  ' becomes '  543+21'. (Note the
        # space used for padding.)
        query = query[::-1]
    # Collect the questions and answers.
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))

# Step 2: vectorize the data, build the RNN model, and split off a validation set.
# x holds 50000 questions of 7 characters, each character one-hot encoded over 12 symbols.
print('Vectorization...')
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)      # shape (50000, 7, 12)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)  # shape (50000, 4, 12)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

# Shuffle (x, y) in unison as the later parts of x will almost all be larger
# digits.
indices = np.arange(len(y))  # array([0, 1, 2, 3, ..., 49999])
np.random.shuffle(indices)   # shuffling the indices defines a random permutation
x = x[indices]               # reorder x and y with the same permutation
y = y[indices]

# Explicitly set apart 10% for validation data that we never train over.
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]  # split into training and validation sets
(y_train, y_val) = y[:split_at], y[split_at:]

print('Training Data:')
print(x_train.shape)
print(y_train.shape)

print('Validation Data:')
print(x_val.shape)
print(y_val.shape)

# Try replacing GRU, or SimpleRNN.
RNN = layers.LSTM  # choose the LSTM layer
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1

print('Build model...')
model = Sequential()  # initialize a sequential model
# "Encode" the input sequence using an RNN, producing an output of HIDDEN_SIZE.
# Note: In a situation where your input sequences have a variable length,
# use input_shape=(None, num_feature).
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))
# As the decoder RNN's input, repeatedly provide with the last hidden state of
# RNN for each time step. Repeat 'DIGITS + 1' times as that's the maximum
# length of output, e.g., when DIGITS=3, max output is 999+999=1998.
model.add(layers.RepeatVector(DIGITS + 1))
# The decoder RNN could be multiple layers stacked or a single layer.
for _ in range(LAYERS):
    # By setting return_sequences to True, return not only the last output but
    # all the outputs so far in the form of (num_samples, timesteps,
    # output_dim). This is necessary as TimeDistributed in the below expects
    # the first dimension to be the timesteps.
    model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# Step 3: training and prediction.
# Apply a dense layer to the every temporal slice of an input. For each of step
# of the output sequence, decide which character should be chosen.
model.add(layers.TimeDistributed(layers.Dense(len(chars))))
model.add(layers.Activation('softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()

# Train the model each generation and show predictions against the validation
# dataset.
for iteration in range(1, 50):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    # Training; the progress output shown in section 2 is produced by this call.
    # fit() is documented in D:\Anaconda3\Lib\site-packages\keras\models.py,
    # but the actual implementation lives in
    # D:\Anaconda3\Lib\site-packages\keras\engine\training.py.
    model.fit(x_train, y_train,                  # NumPy arrays of training data and targets
              batch_size=BATCH_SIZE,             # number of samples per gradient update
              epochs=1,                          # one pass over the training data per iteration
              validation_data=(x_val, y_val))    # data to evaluate on, never trained on
    # Select 10 samples from the validation set at random so we can visualize
    # errors.
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        preds = model.predict_classes(rowx, verbose=0)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q[::-1] if INVERT else q)
        print('T', correct)
        if correct == guess:
            print(colors.ok + '☑' + colors.close, end=" ")
        else:
            print(colors.fail + '☒' + colors.close, end=" ")
        print(guess)
        print('---')
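As a quick sanity check of the encoding scheme (a usage sketch, not part of the original script), the CharacterTable defined above can be exercised on its own:

table = CharacterTable('0123456789+ ')          # same 12-symbol alphabet as the script

onehot = table.encode('3+1    ', num_rows=7)    # one-hot matrix of shape (7, 12), one row per character
print(onehot.shape)                             # (7, 12)
print(repr(table.decode(onehot)))               # '3+1    '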