TensorFlow实现中文字体分类(四):训练
来源:互联网 发布:postman post传递json 编辑:程序博客网 时间:2024/06/04 01:13
上一篇《TensorFlow实现中文字体分类(三):模型-VGG16》
在训练时用softmax计算交叉熵,容易出现浮点下溢,导致log(0)的计算,这就造成了从此次以后的loss都是Nan,解决方法是限制网络输出范围:tf.log(tf.clip_by_value(pred, 1e-5, 1.0))。学习率过大也会造成Nan,一般出现这种情况的话每次学习率除以10地进行调试。
TensorFlow实现训练大致分两种方法。
最低效的是将data pipeline与训练的graph分割成两部分,然后在session中分次执行。代码示意如左,另一种是将data pipeline写进训练的graph中,让TensorFlow自动多线程处理,代码示意如右。
inputs, outputs = data_pipeline(...)X = tf.placeholder(...)Y = tf.placeholder(...)pred = net(X)loss = loss_func(pred, Y)train_op = optimizer.minimize(loss)trainX, trainY = sess.run([inputs, outputs])sess.run(train_op, feed_dict={X:trainX, Y:trainY})
inputs, outputs = data_pipeline(...)pred = net(inputs)loss = loss_func(pred, outputs)train_op = optimizer.minimize(loss)sess.run(train_op)
然而TensorFlow自动多线程的实现并不是很好,设置batch size 128,iter 1000次测试两种方法,分别耗时665.52s, 654.39s, 基本差别不大。GPU使用率曲线分别如下:
理论上来说,如果把训练数据全部读取到内存,那么只需要在内存与GPU直接通信就行了,但实际上训练集都会非常大,因此最耗时的是在硬盘读取上。所以要获得高效的训练,最好自己实现多线程。在这里我使用Python自带的Queue库和threading库,用4个producer产生数据,一个consumer训练网络,代码如下:
#!/usr/bin/env python# -*- coding: utf-8 -*-import osimport timeimport Queueimport threadingimport tensorflow as tf from dataset.read_data import read_datafrom nnets.vgg import vggos.environ['CUDA_VISIBLE_DEVICES'] = '1'class_num = 2def data_pipline(batch_size): data_batch, annotation = read_data(batch_size) iterator = data_batch.make_initializable_iterator() inputs, outputs = iterator.get_next() with tf.Session() as sess: sess.run(iterator.initializer) for _ in xrange(250): data = sess.run([inputs, outputs]) message.put(data) message.put(None)def train(): inputs = tf.placeholder(tf.float32, shape=[None, 128, 128, 3]) outputs = tf.placeholder(tf.float32, shape=[None, class_num]) tf.summary.image('inputs', inputs, 16) lr = tf.placeholder(tf.float32) keep_prob = tf.placeholder(tf.float32) pred = vgg(inputs, class_num, keep_prob) with tf.name_scope('cross_entropy'): cross_entropy = tf.reduce_mean(-tf.reduce_sum(outputs * tf.log(tf.clip_by_value(pred, 1e-5, 1.0)), reduction_indices=[1])) tf.summary.scalar('cross_entropy', cross_entropy) with tf.name_scope('accuracy'): correct = tf.equal(tf.argmax(pred, 1), tf.argmax(outputs, 1)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) tf.summary.scalar('accuracy', accuracy) with tf.name_scope('optimizer'): optimizer = tf.train.AdamOptimizer(lr).minimize(cross_entropy) merged = tf.summary.merge_all() saver = tf.train.Saver() with tf.Session() as sess: writer = tf.summary.FileWriter('./log/', sess.graph) sess.run(tf.global_variables_initializer()) i, stop_count = 0, 0 st = time.time() while True: i += 1 if stop_count == producer_num: break msg = message.get() if msg is None: stop_count += 1 continue image, label = msg learning_rate = 1e-5 if i < 500 else 1e-5 sess.run(optimizer, feed_dict={inputs:image, outputs:label, lr:learning_rate, keep_prob:0.5}) # if i % 50 == 0: # summary, acc, l = sess.run([merged, accuracy, cross_entropy], feed_dict={inputs:image, outputs:label ,keep_prob:1.0}) # print 'iter:{}, acc:{}, loss:{}'.format(i, acc, l) # writer.add_summary(summary, i) print 'run time: ', time.time() - st saver.save(sess, './models/vgg.ckpt') returnif __name__ == '__main__': BATCH_SIZE = 128 producer_num = 4 message = Queue.Queue(200) for i in xrange(producer_num): producer_name = 'p{}'.format(i) locals()[producer_name] = threading.Thread(target=data_pipline, args=(BATCH_SIZE,)) locals()[producer_name].start() c = threading.Thread(target=train)1 c.start() message.join()
耗时527.11s,下图是GPU使用率,可以看到基本上是100%。取消76-80行的注释会把中间结果写进tensorboard,但会多耗时一些,在执行这个步骤时GPU使用率也会降到0。
在这里只使用Baoli和Xingkai两种字体来做二分类,下图分别是训练时的accuracy和loss
下一篇《TensorFlow实现中文字体分类(五):评估》
- TensorFlow实现中文字体分类(四):训练
- TensorFlow实现中文字体分类(一):预处理
- TensorFlow实现中文字体分类(二):数据流
- TensorFlow实现中文字体分类(五):评估
- TensorFlow实现中文字体分类(三):模型-VGG16
- TensorFlow(四)分类
- tensorflow实现文本分类
- TensorFlow实现图片分类
- Tensorflow 实现二分类
- TensorFlow实现 mnist分类
- 如何重新训练Tensorflow图像分类模型
- TensorFlow学习笔记(四):Tensorflow网络构建和TensorBoard进行训练过程可视化
- 在Android设备上配置TensorFlow(四)无法使用TensorFlow训练新model
- alexnet tensorflow 实现和训练
- TensorFlow简要教程系列(四)TensorFlow实现Softmax回归
- 代码,逻辑回归(logistic_regression)实现mnist分类(TensorFlow实现)
- 利用Tensorflow实现SSD架构model训练(voc2012)
- Tensorflow实现逻辑分类器
- webview如何加载HTML,CSS等语言
- Linux引导流程(一)
- Ajax跨域请求action方法,无法传递及接收cookie信息(应用于系统登录认证及退出)解决方案
- c++,ccf,2017年9月,打酱油试题
- Ubuntu16.04下Anaconda3+tensorflow+Pycharm+Spyder安装与配置
- TensorFlow实现中文字体分类(四):训练
- wps很霸道啊,在用wps开启了一条文档后,用word将不能打开任何文档,即wps在运行的时候,word就用不了了,这只能是wps做的手脚
- 为Word2013文档解决出现乱码的问题
- Git初学 小白笔记(一)
- iOS App上架流程
- oozie 工作流调度引擎总结(一)
- 关于IntelliJ Idea搭建javaweb项目出现的Error filterStart错误解决
- 面试题之“度度熊”
- android调用dialog.hide()引起的输入事件派发错误问题追踪