Tensorflow 如何使用自己cifar10训练模型检测一张任意的图片
来源:互联网 发布:黑魂3男捏脸数据 编辑:程序博客网 时间:2024/05/16 01:47
Tensorflow如何使用自己cifar10训练模型检测一张任意的图片
研究了cifar10数据集1个月了,终于实现了cifar10训练模型验证一张图片的全部过程,网上给的例子要么caffe实现,要么就是说半截或者是给的代码不全的,实在是无语。自己将我的探究成果写个博客,希望能帮助更多的人少走一些弯路。本博客的代码在官方的例子基础上进行的改版
cifar10数据集的简单介绍
共分为10类,具体的分类如下图所示:
60000张图片里面有:
50000张训练样本
10000张测试样本(验证Set)图片是三通道RGB的彩色图片,大小是32x32像素,3*32*32==3*1024==3072,存储在numpy的时候,前1024位是RGB中的R分量像素值,
中间的1024位是G分量的像素值,最后的1024是B分量的像素值
cifar10验证单张图片
1. 验证单张图片,你需要先处理读取的图片,将其处理成 [batch_size, height, width, channels]四维的tensor
2. 调用cifar10.py 中的 inference 函数,对输入图片进行卷积、池化、本地化等操作,之后获取最终的 logits
3. 加载训练模型时保存的恢复点,并对测试图片进行预测
要实现上诉的操作,常用的有类似MNIST使用placeholder的方式,这种方式可以参看一个不完整的样例 How to classify images using tensorflow cifar10 model
我实现的方式是另一种方式,代码如下:
# -*- coding:utf-8 -*-import tensorflow as tffrom tensorflow.python.ops.image_ops_impl import ResizeMethodfrom prettytable import PrettyTable # PrettyTable使用参看http://www.ithao123.cn/content-2560565.htmlimport cifar10import numpy as npFLAGS = tf.app.flags.FLAGS# 设置存储模型训练结果的路径tf.app.flags.DEFINE_string('checkpoint_dir', '/home/xzy/cifar10_train_xzy', """Directory where to read model checkpoints.""")tf.app.flags.DEFINE_string('class_dir', '/home/xzy/cifar10-input/cifar-10-batches-bin/', """存储文件batches.meta.txt的目录""")tf.app.flags.DEFINE_string('test_file', '/home/xzy/dog.jpg', """测试用的图片""")IMAGE_SIZE = 24def evaluate_images(images): # 执行验证 logits = cifar10.inference(images, batch_size=1) load_trained_model(logits=logits)def load_trained_model(logits): with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: # 从训练模型恢复数据 saver = tf.train.Saver() saver.restore(sess, ckpt.model_checkpoint_path) else: print('No checkpoint file found') return # 下面两行是预测最有可能的分类 # predict = tf.argmax(logits, 1) # output = predict.eval() # 从文件以字符串方式获取10个类标签,使用制表格分割 cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t') # 预测最大的三个分类 top_k_pred = tf.nn.top_k(logits, k=3) output = sess.run(top_k_pred) probability = np.array(output[0]).flatten() # 取出概率值,将其展成一维数组 index = np.array(output[1]).flatten() # 使用表格的方式显示 tabel = PrettyTable(["index", "class", "probability"]) tabel.align["index"] = "l" tabel.padding_width = 1 for i in np.arange(index.size): tabel.add_row([index[i], cifar10_class[index[i]], probability[i]]) print tabeldef img_read(filename): if not tf.gfile.Exists(filename): tf.logging.fatal('File does not exists %s', filename) image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(filename), channels=3), dtype=tf.float32) height = IMAGE_SIZE width = IMAGE_SIZE image = tf.image.resize_images(image_data, (height, width), method=ResizeMethod.BILINEAR) image = tf.expand_dims(image, -1) image = tf.reshape(image, (1, 24, 24, 3)) return imagedef main(argv=None): # pylint: disable=unused-argument filename = FLAGS.test_file images = img_read(filename) evaluate_images(images)if __name__ == '__main__': tf.app.run()
上诉的代码,给定一张狗的图片,显示最大的三类识别的结果,最后是用prettytabel打印出的效果+-------+-------+-------------+| index | class | probability |+-------+-------+-------------+| 3 | cat | 0.530751 || 5 | dog | 0.491245 || 2 | bird | 0.139152 |+-------+-------+-------------+由于训练的样本设置成100,模型的准确率反而cat的较高,这很正常本代码的原创之处,自己探索好久尝试写出来的,其他地方绝对没有:
1. 对inference函数加入一个参数,让其默认的值是原来的128,验证单张的时候传入1,这样不会影响原来的测试样本集的验证
2. 使用 np.loadtxt 来读取cifar10的10个类,方便后续知道下标,获取cifar10的种类名称
3. tf.nn.top_k(logits, k=3) 来显示最大的3个类别的概率和index下标
4. 使用prettytable格式化输出
完整的代码,见本人github https://github.com/xzy256/cifar10_xzy ,欢迎评论
参考文献
将二进制转换成图片
tensorflow学习之识别单张图片的实现(python手写数字)
tensorflow实现embedding展示的简单快速构建例子
tensorflow使用cifar10模型进行验证单张图片的代码
tensorfloe模块化函数实现cifar10数据集上测试单张图片
cifar10获取10个类别的方法
- Tensorflow 如何使用自己cifar10训练模型检测一张任意的图片
- 【深度学习】笔记6:使用caffe中的CIFAR10网络模型和自己的图片数据训练自己的模型(步骤详解)
- caffe——cifar10模型训练自己的数据
- tensorflow object_detection 用自己的数据训练目标检测模型Mobilenet
- Tensorflow训练自己的Object Detection模型并进行目标检测
- TensorFlow——训练自己的数据——CIFAR10(一)数据准备
- Faster RCNN 训练自己的检测模型
- Faster RCNN 训练自己的检测模型
- mxnet 使用自己的图片数据训练CNN模型
- 使用lenet模型训练及预测自己的图片数据
- 利用tensorflow训练自己的图片数据(3)——建立网络模型
- VggNet10模型的cifar10深度学习训练
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-3结合caffe中的CIFAR10修改相关配置文件并训练
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-1.制作自己的数据集
- 使用Tensorflow训练自己的分割数据
- 使用tensorflow训练自己的数据
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-4应用生成模型进行预测
- caffe----训练自己的图片caffenet模型
- js获取html内容
- 创建对象的七种方式
- 对ffmpeg的时间戳的理解笔记
- 2017 Multi-University Training Contest
- 计算机网络基础知识
- Tensorflow 如何使用自己cifar10训练模型检测一张任意的图片
- QT 实现自定义的IP地址控件
- Lumen 有问题? 填坑吧
- Vue中Class与Style绑定
- Android四大组件之Activity
- 网狐6.6完整商业版源码架设最新棋牌游戏源码下载
- 推荐系统
- js模拟手机短信对话
- java 用单链表实现队列