TensorFlow 遇坑小结
来源:互联网 发布:华盛顿大学知乎 编辑:程序博客网 时间:2024/05/18 22:13
简介
对于TensorFlow和python新手来说,写代码就是坑,遇到各种问题,然而却只能慢慢解决。
GPU溢出OOM问题
报错大致如下:
ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[]......
尝试:
我需要跑大概10000张图片,大小为256x256,通道为RGB3通道。最开始请教实验室同学说是因为图片太大,batch_size调小点。然而,
num = 100, batch_size = 32 # 尝试了100张图片,训练了几百轮报上述错误OOMnum = 500, batch_size = 16或8 # 依然如故num = 1000, batch_size = 1 # 这时跑了很久,但还是不行,我想跟batch_size无关了... # 后面我尝试调了图片训练数量和batch_size
猜测:
与batch_size和训练集图片数量无关。
解决:
查看代码,看了几遍发现只有如下部分疑似有问题:
with tf.Session() as sess: for t in range(0, num-batch_size, batch_size) xs_batch, ys_batch = sess.run([X_train[t:t+batch_size],Y_train[t:t+batch_size]])
改动后,如下:
with tf.Session() as sess: # sess.run(tf.global_variables_initializer()) saver.restore(sess, model_path) # writer = tf.summary.FileWriter('./graphs', sess.graph) for i in range(iters): for t in range(0, train_num-batch_size, batch_size): xs_batch, ys_batch = inputs(t, t+batch_size) sess.run(train_step, feed_dict={xs:xs_batch, ys:ys_batch}) if t % 10 == 0: cost = sess.run(cost_function, feed_dict={xs: xs_batch, ys: ys_batch}) print('iters:%s, batch_add:%s, loss:%s' % (i, t, cost)) file_log.write('iters:%s, batch_add:%s, loss:%s \n' % (i, t, cost)) if i % 100 == 0: saver.save(sess, model_path) # writer.close() sess.close()
总结
sess.run()不能频繁使用,尤其是处理大数据集的时候,尽量避免sess.run(),如果只是简单的预测值无所谓。
学习率
问题
当最开始设置较大例如0.5,然后loss却一直不变,是那种波动都不波动的。
猜测
可能与最开始参数随机初始化时,最开始的反向传播梯度无法下降,原因可能与初始化函数的标准差设置有关。
解决办法
逐渐调小学习率,例如我的rate=0.5、0.2、0.12、0.1…直至最后的0.08梯度开始下降,这时候保存参数模型,用这个参数模型再进行训练,这时可以适当增大学习率保证速度。例如保存0.08的参数模型、再将梯度适当调大一点,或者不调,还是0.08。
结束语
后续会补充坑集,OOM只是坑集中的第一个,庆幸解决了。下面是本科毕设项目,正在完成ing。
https://github.com/wangleihitcs/face-enhance
阅读全文
0 0
- TensorFlow 遇坑小结
- Tensorflow学习资料小结
- tensorflow 语法小结
- TENSORFLOW官方文档-MNIST小结
- win10下tensorflow安装中的问题小结
- TensorFlow Variable, Placeholder 以及激励函数学习小结
- http://www.52nlp.cn/tag/tensorflow Andrew Ng (吴恩达) 深度学习课程小结 tensorflow
- tensorflow安装的坑
- tensorflow试用踩坑
- Tensorflow爬过的坑
- Tensorflow 填坑日记
- tensorflow
- TensorFlow
- TensorFlow
- tensorflow
- tensorflow
- tensorflow
- Tensorflow
- java.lang.ClassNotFoundException: Didn't find class "android.hardware.fingerprint.FingerprintManager
- Lombok库的应用
- Gibbs Sampling for Gaussian Mixture Model
- [电脑问题]如何把3.5英寸的硬盘安装到没有硬盘架的新电脑
- TCP/IP、UDP、Http、Socket的区别
- TensorFlow 遇坑小结
- 20171212Link
- SpringBoot 之 普通类获取Spring容器中的bean
- shell获取执行超过1天时间的进程
- 如何配置MySQL远程连接
- 切换手机的输入法
- 我的DBA之路——对内存体系的整理
- Kafka的Log存储解析
- DLL调用(1):C++静态调用DLL