A Top-Down Analysis of a Simple Speech Recognition System (Part 9)

Source: Internet  Editor: 程序博客网  Time: 2024/04/30 03:14

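In the previous installments we finished analyzing the configuration stage of the run_model function and the generation of the input and output vectors. In this installment we continue with the actual training process that follows.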

1. The run_training_epochs function

Training is driven mainly by this function. The code is as follows:

    def run_training_epochs(self):
        train_start = time.time()
        for epoch in range(self.epochs):
            # Initialize variables that can be updated
            # self.epochs is read from the configuration (self.epochs = 2 here)
            save_dev_model = False
            stop_training = False
            is_checkpoint_step, is_validation_step = \
                self.validation_and_checkpoint_check(epoch)

            epoch_start = time.time()

            self.train_cost, self.train_ler = self.run_batches(
                self.data_sets.train,
                is_training=True,
                decode=False,
                write_to_file=False,
                epoch=epoch)

            epoch_duration = time.time() - epoch_start

            log = 'Epoch {}/{}, train_cost: {:.3f}, train_ler: {:.3f}, time: {:.2f} sec'
            logger.info(log.format(
                epoch + 1,
                self.epochs,
                self.train_cost,
                self.train_ler,
                epoch_duration))

            summary_line = self.sess.run(
                self.train_ler_op, {self.ler_placeholder: self.train_ler})
            self.writer.add_summary(summary_line, epoch)

            summary_line = self.sess.run(
                self.train_cost_op, {self.cost_placeholder: self.train_cost})
            self.writer.add_summary(summary_line, epoch)

            # Shuffle the data for the next epoch
            if self.shuffle_data_after_epoch:
                np.random.shuffle(self.data_sets.train._txt_files)

            # Run validation if it was determined to run a validation step
            if is_validation_step:
                self.run_validation_step(epoch)

            if (epoch + 1) == self.epochs or is_checkpoint_step:
                # save the final model
                save_path = self.saver.save(self.sess, os.path.join(
                    self.SESSION_DIR, 'model.ckpt'), epoch)
                logger.info("Model saved: {}".format(save_path))

            if save_dev_model:
                # If the dev set is not improving,
                # the training is killed to prevent overfitting
                # And then save the best validation performance model
                save_path = self.saver.save(self.sess, os.path.join(
                    self.SESSION_DIR, 'model-best.ckpt'))
                logger.info(
                    "Model with best validation label error rate saved: {}".
                    format(save_path))

            if stop_training:
                break

        train_duration = time.time() - train_start
        logger.info('Training complete, total duration: {:.2f} min'.format(
            train_duration / 60))

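The call to validation_and_checkpoint_check determines whether the current epoch is a checkpoint step and/or a validation step;
The call to run_batches feeds the data_sets.train data into the network for training and returns the cost and label error rate for the epoch;
The two sess.run calls evaluate the summaries for train_ler and train_cost, which writer.add_summary then records;
shuffle_data_after_epoch controls whether the training data is shuffled after each epoch;
is_validation_step controls whether run_validation_step is called;
The remaining if blocks save the trained model parameters: model.ckpt at checkpoint steps and after the last epoch, and model-best.ckpt when save_dev_model is set;
As you can see, the key part of this function is run_batches, so that is the function we analyze next.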

2. The run_batches function

    def run_batches(self, dataset, is_training, decode, write_to_file, epoch):
        n_examples = len(dataset._txt_files)
        n_batches_per_epoch = int(np.ceil(n_examples / dataset._batch_size))

        self.train_cost = 0
        self.train_ler = 0

        for batch in range(n_batches_per_epoch):
            # Get next batch of training data (audio features) and transcripts
            source, source_lengths, sparse_labels = dataset.next_batch()

            feed = {self.input_tensor: source,
                    self.targets: sparse_labels,
                    self.seq_length: source_lengths}

            # If the is_training is false, this means straight decoding without computing loss
            if is_training:
                # avg_loss is the loss_op, optimizer is the train_op;
                # running these pushes tensors (data) through graph
                batch_cost, _ = self.sess.run(
                    [self.avg_loss, self.optimizer], feed)
                self.train_cost += batch_cost * dataset._batch_size
                logger.debug('Batch cost: %.2f | Train cost: %.2f',
                             batch_cost, self.train_cost)

            self.train_ler += self.sess.run(self.ler, feed_dict=feed) * dataset._batch_size
            logger.debug('Label error rate: %.2f', self.train_ler)

            # Turn on decode only 1 batch per epoch
            if decode and batch == 0:
                d = self.sess.run(self.decoded[0], feed_dict={
                    self.input_tensor: source,
                    self.targets: sparse_labels,
                    self.seq_length: source_lengths}
                )
                dense_decoded = tf.sparse_tensor_to_dense(
                    d, default_value=-1).eval(session=self.sess)
                dense_labels = sparse_tuple_to_texts(sparse_labels)

                # only print a set number of example translations
                counter = 0
                counter_max = 4
                if counter < counter_max:
                    for orig, decoded_arr in zip(dense_labels, dense_decoded):
                        # convert to strings
                        decoded_str = ndarray_to_text(decoded_arr)
                        logger.info('Batch {}, file {}'.format(batch, counter))
                        logger.info('Original: {}'.format(orig))
                        logger.info('Decoded:  {}'.format(decoded_str))
                        counter += 1

                # save out variables for testing
                self.dense_decoded = dense_decoded
                self.dense_labels = dense_labels

        # Metrics mean
        if is_training:
            self.train_cost /= n_examples
        self.train_ler /= n_examples

        # Populate summary for histograms and distributions in tensorboard
        self.accuracy, summary_line = self.sess.run(
            [self.avg_loss, self.summary_op], feed)
        self.writer.add_summary(summary_line, epoch)

        return self.train_cost, self.train_ler

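The feed dictionary, built from source, sparse_labels and source_lengths, is the feed_dict handed to sess.run;
The if is_training branch runs one training step and obtains the corresponding batch cost;
The if decode and batch == 0 branch decodes the first batch of the epoch to obtain the output sequences.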
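One detail worth spelling out is how run_batches turns per-batch averages into per-epoch averages: each batch value is multiplied by the batch size, summed, and finally divided by n_examples. The sketch below isolates that arithmetic in plain Python. The helper names epoch_mean and n_batches are made up for illustration and do not appear in the original code.

```python
import math


def n_batches(n_examples, batch_size):
    """Number of batches per epoch, rounded up as in run_batches."""
    return int(math.ceil(n_examples / batch_size))


def epoch_mean(batch_means, batch_size, n_examples):
    """Combine per-batch means into one per-example mean.

    Hypothetical helper: mirrors the accumulate-then-divide pattern
    used for train_cost and train_ler.
    """
    total = 0.0
    for batch_mean in batch_means:
        # Undo the per-batch averaging by weighting with the batch size.
        total += batch_mean * batch_size
    return total / n_examples


if __name__ == "__main__":
    # 8 examples in 2 full batches of 4: the result is the plain average.
    print(epoch_mean([2.0, 4.0], batch_size=4, n_examples=8))   # 3.0
    # 10 examples need 3 batches of 4; 12 weighted samples are divided by 10.
    print(n_batches(10, 4))                                     # 3
    print(epoch_mean([1.0, 1.0, 1.0], batch_size=4, n_examples=10))  # 1.2
```

The last call shows a consequence of this scheme: when n_examples is not a multiple of the batch size, every batch is still weighted by the full batch size, so the reported mean is scaled up slightly (1.2 instead of 1.0 here).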

3. The validation_and_checkpoint_check function

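As mentioned earlier, this function decides at which epochs the model is saved and at which epochs it is validated. The code is as follows: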

    def validation_and_checkpoint_check(self, epoch):
        # initially set at False unless indicated to change
        is_checkpoint_step = False
        is_validation_step = False

        # Check if the current epoch is a validation or checkpoint step
        if (epoch > 0) and ((epoch + 1) != self.epochs):
            if (epoch + 1) % self.SAVE_MODEL_EPOCH_NUM == 0:
                is_checkpoint_step = True
            if (epoch + 1) % self.VALIDATION_EPOCH_NUM == 0:
                is_validation_step = True

        return is_checkpoint_step, is_validation_step

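SAVE_MODEL_EPOCH_NUM and VALIDATION_EPOCH_NUM are both set in the configuration file, so this function ensures that the network model is saved and validated at fixed intervals. Note that the outer condition excludes both the first epoch (epoch == 0) and the last epoch; the save after the last epoch is handled separately in run_training_epochs by the (epoch + 1) == self.epochs test.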
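To see which epochs this rule actually selects, here is a small standalone sketch. It copies the condition into a free function (the parameter names epochs, save_every and validate_every are stand-ins for self.epochs, SAVE_MODEL_EPOCH_NUM and VALIDATION_EPOCH_NUM) and adds a hypothetical schedule helper that lists the flagged epochs.

```python
def validation_and_checkpoint_check(epoch, epochs, save_every, validate_every):
    """Standalone copy of the scheduling rule; parameter names are illustrative."""
    is_checkpoint_step = False
    is_validation_step = False
    # The first epoch (epoch == 0) and the final epoch are excluded here.
    if epoch > 0 and (epoch + 1) != epochs:
        if (epoch + 1) % save_every == 0:
            is_checkpoint_step = True
        if (epoch + 1) % validate_every == 0:
            is_validation_step = True
    return is_checkpoint_step, is_validation_step


def schedule(epochs, save_every, validate_every):
    """Hypothetical helper: list the 0-based epochs flagged for each action."""
    checkpoints, validations = [], []
    for epoch in range(epochs):
        is_ckpt, is_val = validation_and_checkpoint_check(
            epoch, epochs, save_every, validate_every)
        if is_ckpt:
            checkpoints.append(epoch)
        if is_val:
            validations.append(epoch)
    return checkpoints, validations


if __name__ == "__main__":
    print(schedule(epochs=10, save_every=5, validate_every=2))
    # ([4], [1, 3, 5, 7])
    print(schedule(epochs=2, save_every=1, validate_every=1))
    # ([], [])
```

The second call is the interesting one for this series: with epochs set to 2, as in the configuration used here, no epoch is ever flagged, so no validation step runs and the only save is the one after the last epoch.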

4. The run_validation_step function

As shown above, once the model has been trained for a certain number of epochs, run_validation_step is called to validate it. The code is as follows:

    def run_validation_step(self, epoch):
        dev_ler = 0

        _, dev_ler = self.run_batches(self.data_sets.dev,
                                      is_training=False,
                                      decode=True,
                                      write_to_file=False,
                                      epoch=epoch)

        logger.info('Validation Label Error Rate: {}'.format(dev_ler))

        summary_line = self.sess.run(
            self.dev_ler_op, {self.ler_placeholder: dev_ler})
        self.writer.add_summary(summary_line, epoch)

        if dev_ler < self.min_dev_ler:
            self.min_dev_ler = dev_ler

        # average historical LER
        history_avg_ler = np.mean(self.AVG_VALIDATION_LERS)

        # if this LER is not better than average of previous epochs, exit
        if history_avg_ler - dev_ler <= self.CURR_VALIDATION_LER_DIFF:
            log = "Validation label error rate not improved by more than {:.2%} \
                  after {} epochs. Exit"
            warnings.warn(log.format(self.CURR_VALIDATION_LER_DIFF,
                                     self.AVG_VALIDATION_LER_EPOCHS))

        # save avg validation accuracy in the next slot
        self.AVG_VALIDATION_LERS[
            epoch % self.AVG_VALIDATION_LER_EPOCHS] = dev_ler

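As the code shows, validation runs on the data in self.data_sets.dev. The new dev_ler is compared with the mean of the values stored in AVG_VALIDATION_LERS: if it does not fall below that mean by more than CURR_VALIDATION_LER_DIFF, a warning is issued. Although the message ends with "Exit", the listings in this post only warn; nothing shown here sets stop_training or otherwise ends training.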
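The bookkeeping around AVG_VALIDATION_LERS is easier to follow in isolation. Below is a minimal sketch using NumPy. The LerHistory class is hypothetical, and the initial value of 1.0 for every slot is an assumption, since the initialization of AVG_VALIDATION_LERS is not shown in this post.

```python
import numpy as np


class LerHistory:
    """Hypothetical helper mirroring the history check in run_validation_step."""

    def __init__(self, window, min_improvement, initial_ler=1.0):
        # Assumption: every slot starts at a pessimistic LER of 1.0.
        self.window = window                      # AVG_VALIDATION_LER_EPOCHS
        self.min_improvement = min_improvement    # CURR_VALIDATION_LER_DIFF
        self.lers = np.full(window, initial_ler, dtype=float)

    def update(self, epoch, dev_ler):
        """Record dev_ler; return True if it failed to beat the history mean."""
        history_avg = np.mean(self.lers)
        stalled = bool(history_avg - dev_ler <= self.min_improvement)
        # Overwrite the slot indexed by epoch % window, as the original does.
        self.lers[epoch % self.window] = dev_ler
        return stalled


if __name__ == "__main__":
    history = LerHistory(window=2, min_improvement=0.005)
    print(history.update(0, 0.80))  # False: 1.0 - 0.80 is a clear improvement
    print(history.update(1, 0.60))  # False: mean 0.9, improvement 0.3
    print(history.update(2, 0.75))  # True: mean 0.7, the LER got worse
```

Note that the comparison happens before the new value is written, so each dev_ler is judged against earlier validation results only.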
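That completes our analysis of the code for the whole training process. What remains open are the sparse_tuple_to_texts and ndarray_to_text functions, which are called when decoding the output vectors and which we have not yet analyzed. We will go through them in detail in the next installment.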
