TensorFlow CNN 测试CIFAR-10数据集
来源:互联网 发布:python程序员出路 编辑:程序博客网 时间:2024/04/29 08:07
本系列文章由 @yhl_leo 出品,转载请注明出处。
文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311
1 CIFAR-10 数据集
CIFAR-10数据集是机器学习中的一个通用的用于图像识别的基础数据集,官网链接为:The CIFAR-10 dataset
下载使用的版本是:
将其解压后(代码中包含自动解压代码),内容为:
2 测试代码
测试代码公布在GitHub:yhlleo
主要代码及作用:
cifar10_input.py
读取本地或者在线下载CIFAR-10的二进制文件格式数据集 cifar10.py
建立CIFAR-10的模型 cifar10_train.py
在CPU或GPU上训练CIFAR-10的模型 cifar10_multi_gpu_train.py
在多个GPU上训练CIFAR-10的模型 cifar10_eval.py
评估CIFAR-10模型的预测性能该部分的代码,介绍了如何使用TensorFlow在CPU和GPU上训练和评估卷积神经网络(convolutional neural network, CNN)。
3 相关网页及教程
更加详细地介绍说明,请浏览网页:Convolutional Neural Networks
中文网站极客学院也有该部分的汉译版:卷积神经网络
代码源自tensorflow官网:tensorflow/models/image/cifar10
4 代码修改说明
GitHub公布代码相对源码(本人的Tensorflow版本还是0.5),主要进行了以下修正:
cifar10.py
# indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])# orindices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
此处,源码编译时会出现以下错误:
... File ".../cifar10.py", line 271, in loss indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])TypeError: range() takes at least 2 arguments (1 given)
cifar10_input_test.py
#self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))import compat as cp...self.assertEqual("%s:%d" % (filename, i), cp.as_text(key))
不然的话,我测试的时候就会出现这的错误:
AttributeError: 'module' object has no attribute 'compat'
cifar10_train.py
和cifar10_multi_gpu_train.py
源代码里的最大迭代次数max_steps
为1000000
,需要训练几个小时,不忍心折腾我的破笔记本,就改为了20000
。
其他改动,例如导入模块或者文件路径等,都很容易理解,就不列举了~
运行结果,与官网上公布的一致,也不再列举。附上一张运行结果截图:
7 1
- TensorFlow CNN 测试CIFAR-10数据集
- TensorFlow-CNN CIFAR-10数据集 学习
- 【学习笔记】机器学习之用TensorFlow cnn 测试CIFAR-10数据集
- TensorFlow学习笔记---CNN分类CIFAR-10数据集3
- TensorFlow CIFAR-10数据集
- TensorFlow学习笔记(8)----CNN分类CIFAR-10数据集
- TensorFlow应用之进阶版卷积神经网络CNN在CIFAR-10数据集上分类
- TensorFlow深度学习进阶教程:TensorFlow实现CIFAR-10数据集测试的卷积神经网络
- Keras基于Cifar-10数据集的CNN实现
- tensorflow官网Cifar-10改为自己的TFRecords数据集
- tensorflow 卷积神经网络实现CIFAR-10数据集识别
- CNN & Tensorflow 入门——以Cifar-10为例
- Cifar-10数据集的训练与测试
- [keras实战] 小型CNN实现Cifar-10数据集84%准确率
- (CNN笔记整理)类CIFAR数据集的产生
- MINST数据TensorFlow中CNN测试
- CNN训练Cifar-10技巧
- CNN训练Cifar-10技巧
- JAVA常用集合框架用法详解——提高篇
- Android Studio一步步教你集成发布适配
- 在ubuntu中出现Call to undefined function: mysql_connect()
- /usr/bin/ld: cannot find -lgcc
- UltraWebGrid中列固定效果
- TensorFlow CNN 测试CIFAR-10数据集
- android studio使用记录
- iOS蓝牙开发数据实时传输
- 请问在VC++2010中如何连接用Access2010创建好的accdb数据库?
- js获取浏览器信息及屏幕分辨率
- C++的运算符重载
- java学习笔记-继承extends
- Java类初始化顺序
- 关于Server Error in '/' Application.错误