A Top-Down Analysis of a Simple Speech Recognition System (9)
Source: Internet | Editor: 程序博客网 | Published: 2024/04/30 03:14
In the previous installments we finished analyzing the configuration phase of run_model and the construction of the input and output vectors. In this installment we walk through the actual training process.
1. The run_training_epochs function
Training is driven by this function; the code is shown below:
```python
def run_training_epochs(self):
    train_start = time.time()
    for epoch in range(self.epochs):
        # Initialize variables that can be updated
        # self.epochs = 2, read from the configuration
        save_dev_model = False
        stop_training = False
        is_checkpoint_step, is_validation_step = \
            self.validation_and_checkpoint_check(epoch)

        epoch_start = time.time()
        self.train_cost, self.train_ler = self.run_batches(
            self.data_sets.train,
            is_training=True,
            decode=False,
            write_to_file=False,
            epoch=epoch)
        epoch_duration = time.time() - epoch_start

        log = 'Epoch {}/{}, train_cost: {:.3f}, train_ler: {:.3f}, time: {:.2f} sec'
        logger.info(log.format(
            epoch + 1, self.epochs, self.train_cost, self.train_ler, epoch_duration))

        summary_line = self.sess.run(
            self.train_ler_op, {self.ler_placeholder: self.train_ler})
        self.writer.add_summary(summary_line, epoch)

        summary_line = self.sess.run(
            self.train_cost_op, {self.cost_placeholder: self.train_cost})
        self.writer.add_summary(summary_line, epoch)

        # Shuffle the data for the next epoch
        if self.shuffle_data_after_epoch:
            np.random.shuffle(self.data_sets.train._txt_files)

        # Run validation if it was determined to run a validation step
        if is_validation_step:
            self.run_validation_step(epoch)

        if (epoch + 1) == self.epochs or is_checkpoint_step:
            # save the final model
            save_path = self.saver.save(self.sess, os.path.join(
                self.SESSION_DIR, 'model.ckpt'), epoch)
            logger.info("Model saved: {}".format(save_path))

        if save_dev_model:
            # If the dev set is not improving,
            # the training is killed to prevent overfitting
            # And then save the best validation performance model
            save_path = self.saver.save(self.sess, os.path.join(
                self.SESSION_DIR, 'model-best.ckpt'))
            logger.info(
                "Model with best validation label error rate saved: {}".
                format(save_path))

        if stop_training:
            break

    train_duration = time.time() - train_start
    logger.info('Training complete, total duration: {:.2f} min'.format(
        train_duration / 60))
```
The call to validation_and_checkpoint_check first determines whether the current epoch is a checkpoint step and/or a validation step;
run_batches is then invoked with data_sets.train to run one full epoch of training;
the two sess.run calls push the epoch's label error rate and cost into TensorBoard summaries via self.writer;
shuffle_data_after_epoch controls whether the training data is shuffled after each epoch;
is_validation_step controls whether a validation pass is performed;
the final if blocks save the trained model parameters: periodic checkpoints and the final model, plus the best-on-dev model when save_dev_model is set.
Clearly, the heart of this function is run_batches, which we analyze next.
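The per-epoch control flow above (train, then conditionally validate and checkpoint, always saving the final model) can be sketched with a minimal, framework-free driver. All names here are hypothetical stand-ins; the real class wires in a TensorFlow session, saver, and summary writer:

```python
def run_epochs_sketch(train_fn, epochs, save_every, validate_every):
    """Minimal sketch of the epoch loop: run training, then decide per
    epoch whether to validate and/or checkpoint. Mirrors the rule that
    intermediate steps skip the first and last epoch, while the final
    epoch is always checkpointed."""
    events = []
    for epoch in range(epochs):
        cost = train_fn(epoch)                      # one pass over the data
        events.append(('train', epoch, cost))
        is_last = (epoch + 1) == epochs
        in_middle = epoch > 0 and not is_last
        if in_middle and (epoch + 1) % validate_every == 0:
            events.append(('validate', epoch))
        if is_last or (in_middle and (epoch + 1) % save_every == 0):
            events.append(('checkpoint', epoch))    # final model always saved
    return events

events = run_epochs_sketch(lambda e: 10.0 - e, epochs=4,
                           save_every=2, validate_every=2)
```

With four epochs and a period of 2, epoch 1 triggers both validation and a checkpoint, and epoch 3 (the last) is checkpointed but not validated.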
2. The run_batches function
```python
def run_batches(self, dataset, is_training, decode, write_to_file, epoch):
    n_examples = len(dataset._txt_files)
    n_batches_per_epoch = int(np.ceil(n_examples / dataset._batch_size))

    self.train_cost = 0
    self.train_ler = 0

    for batch in range(n_batches_per_epoch):
        # Get next batch of training data (audio features) and transcripts
        source, source_lengths, sparse_labels = dataset.next_batch()

        feed = {self.input_tensor: source,
                self.targets: sparse_labels,
                self.seq_length: source_lengths}

        # If is_training is false, this means straight decoding without computing loss
        if is_training:
            # avg_loss is the loss_op, optimizer is the train_op;
            # running these pushes tensors (data) through graph
            batch_cost, _ = self.sess.run(
                [self.avg_loss, self.optimizer], feed)
            self.train_cost += batch_cost * dataset._batch_size
            logger.debug('Batch cost: %.2f | Train cost: %.2f',
                         batch_cost, self.train_cost)

            self.train_ler += self.sess.run(
                self.ler, feed_dict=feed) * dataset._batch_size
            logger.debug('Label error rate: %.2f', self.train_ler)

        # Turn on decode only 1 batch per epoch
        if decode and batch == 0:
            d = self.sess.run(self.decoded[0], feed_dict={
                self.input_tensor: source,
                self.targets: sparse_labels,
                self.seq_length: source_lengths})
            dense_decoded = tf.sparse_tensor_to_dense(
                d, default_value=-1).eval(session=self.sess)
            dense_labels = sparse_tuple_to_texts(sparse_labels)

            # only print a set number of example translations
            counter = 0
            counter_max = 4
            if counter < counter_max:
                for orig, decoded_arr in zip(dense_labels, dense_decoded):
                    # convert to strings
                    decoded_str = ndarray_to_text(decoded_arr)
                    logger.info('Batch {}, file {}'.format(batch, counter))
                    logger.info('Original: {}'.format(orig))
                    logger.info('Decoded: {}'.format(decoded_str))
                    counter += 1

            # save out variables for testing
            self.dense_decoded = dense_decoded
            self.dense_labels = dense_labels

    # Metrics mean
    if is_training:
        self.train_cost /= n_examples
        self.train_ler /= n_examples

    # Populate summary for histograms and distributions in tensorboard
    self.accuracy, summary_line = self.sess.run(
        [self.avg_loss, self.summary_op], feed)
    self.writer.add_summary(summary_line, epoch)

    return self.train_cost, self.train_ler
```
feed is the feed_dict passed to sess.run, packing the batch's audio features, sparse labels, and sequence lengths;
when is_training is true, running avg_loss and optimizer performs one training step and accumulates the batch's cost and label error rate;
when decode is enabled, the first batch of each epoch is decoded, the sparse result is converted to dense form, and a few example transcriptions are logged for inspection.
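The cost bookkeeping in run_batches is a weighted running sum: each batch contributes its mean cost times the batch size, and the total is divided by the number of examples at the end. A minimal sketch of that arithmetic (note the original multiplies every batch by the fixed _batch_size, which slightly over-weights a partial final batch; here we weight by the actual batch length):

```python
def mean_cost_over_batches(batch_costs_and_sizes, n_examples):
    """Accumulate per-batch mean costs weighted by batch size,
    then normalize by the total number of examples."""
    total = 0.0
    for mean_cost, size in batch_costs_and_sizes:
        total += mean_cost * size
    return total / n_examples

# two full batches of 4 examples and a final batch of 2 (10 examples total)
avg = mean_cost_over_batches([(3.0, 4), (1.0, 4), (2.0, 2)], n_examples=10)
# (3*4 + 1*4 + 2*2) / 10 = 2.0
```

The same scheme is used for train_ler, so both reported metrics are per-example averages over the whole epoch.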
3. The validation_and_checkpoint_check function
As mentioned above, this function decides at which epochs the model is saved and validated; the code is as follows:
```python
def validation_and_checkpoint_check(self, epoch):
    # initially set at False unless indicated to change
    is_checkpoint_step = False
    is_validation_step = False

    # Check if the current epoch is a validation or checkpoint step
    if (epoch > 0) and ((epoch + 1) != self.epochs):
        if (epoch + 1) % self.SAVE_MODEL_EPOCH_NUM == 0:
            is_checkpoint_step = True
        if (epoch + 1) % self.VALIDATION_EPOCH_NUM == 0:
            is_validation_step = True

    return is_checkpoint_step, is_validation_step
```
SAVE_MODEL_EPOCH_NUM and VALIDATION_EPOCH_NUM are both set in the configuration file; the function ensures the model is saved and validated at fixed intervals.
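Under these modular-arithmetic rules, neither flag ever fires on the first or last epoch (the final epoch is saved separately inside run_training_epochs). A standalone sketch of the same check, with the config constants passed as plain arguments:

```python
def check(epoch, total_epochs, save_every, validate_every):
    """Mirror of validation_and_checkpoint_check with
    SAVE_MODEL_EPOCH_NUM / VALIDATION_EPOCH_NUM as parameters."""
    is_checkpoint = is_validation = False
    if epoch > 0 and (epoch + 1) != total_epochs:
        is_checkpoint = (epoch + 1) % save_every == 0
        is_validation = (epoch + 1) % validate_every == 0
    return is_checkpoint, is_validation

# with 10 epochs, saving every 3 epochs and validating every 2:
flags = [check(e, 10, 3, 2) for e in range(10)]
```

For example, epoch 2 (the third epoch) is a checkpoint step only, epoch 5 is both, and epochs 0 and 9 are neither.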
4. The run_validation_step function
As shown above, after a fixed number of training epochs we call run_validation_step to validate the model; the code is as follows:
```python
def run_validation_step(self, epoch):
    dev_ler = 0

    _, dev_ler = self.run_batches(self.data_sets.dev,
                                  is_training=False,
                                  decode=True,
                                  write_to_file=False,
                                  epoch=epoch)

    logger.info('Validation Label Error Rate: {}'.format(dev_ler))

    summary_line = self.sess.run(
        self.dev_ler_op, {self.ler_placeholder: dev_ler})
    self.writer.add_summary(summary_line, epoch)

    if dev_ler < self.min_dev_ler:
        self.min_dev_ler = dev_ler

    # average historical LER
    history_avg_ler = np.mean(self.AVG_VALIDATION_LERS)

    # if this LER is not better than average of previous epochs, exit
    if history_avg_ler - dev_ler <= self.CURR_VALIDATION_LER_DIFF:
        log = "Validation label error rate not improved by more than {:.2%} \
              after {} epochs. Exit"
        warnings.warn(log.format(self.CURR_VALIDATION_LER_DIFF,
                                 self.AVG_VALIDATION_LER_EPOCHS))

    # save avg validation accuracy in the next slot
    self.AVG_VALIDATION_LERS[
        epoch % self.AVG_VALIDATION_LER_EPOCHS] = dev_ler
```
As the code shows, validation runs on the data in self.data_sets.dev; if the validation label error rate does not improve on the historical average by more than the configured margin (CURR_VALIDATION_LER_DIFF), a warning is issued.
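The early-stopping bookkeeping is a fixed-size circular buffer of recent validation LERs: the new LER is compared against the mean of the buffer, then written into the slot indexed by epoch modulo the window size, exactly like AVG_VALIDATION_LERS above. A self-contained sketch under assumed parameter names (EarlyStopMonitor, window, min_diff are all hypothetical):

```python
import numpy as np

class EarlyStopMonitor:
    """Track the last `window` validation label error rates and flag
    epochs where the improvement over their mean is <= `min_diff`."""
    def __init__(self, window, min_diff, initial=1.0):
        self.lers = np.full(window, initial)  # history buffer of recent LERs
        self.min_diff = min_diff
        self.window = window

    def update(self, epoch, dev_ler):
        history_avg = float(np.mean(self.lers))
        # warn when the new LER fails to beat the average by min_diff
        should_warn = history_avg - dev_ler <= self.min_diff
        # overwrite the slot for this epoch (circular indexing)
        self.lers[epoch % self.window] = dev_ler
        return should_warn

monitor = EarlyStopMonitor(window=2, min_diff=0.005)
```

While the LER keeps dropping, update returns False; once it plateaus against the recent average, the flag flips and the caller can warn or stop training.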
With that, we have walked through all of the training code. What remains are the sparse_tuple_to_texts and ndarray_to_text functions called when decoding the output vectors; we leave those for the next installment.