Tensorflow 如何使用自己cifar10训练模型检测一张任意的图片

来源:互联网 发布:黑魂3男捏脸数据 编辑:程序博客网 时间:2024/05/16 01:47

Tensorflow如何使用自己cifar10训练模型检测一张任意的图片

研究了cifar10数据集1个月了,终于实现了cifar10训练模型验证一张图片的全部过程,网上给的例子要么caffe实现,要么就是说半截或者是给的代码不全的,实在是无语。自己将我的探究成果写个博客,希望能帮助更多的人少走一些弯路。本博客的代码在官方的例子基础上进行的改版


  • cifar10数据集的简单介绍

共分为10类,具体的分类如下图所示:
60000张图片里面有:
50000张训练样本
10000张测试样本(验证Set)

图片是三通道RGB的彩色图片,大小是32x32像素,3*32*32==3*1024==3072,存储在numpy的时候,前1024位是RGB中的R分量像素值,

中间的1024位是G分量的像素值,最后的1024是B分量的像素值

  • cifar10验证单张图片

1. 验证单张图片,你需要先处理读取的图片,将其处理成 [batch_size, height, width, channels]四维的tensor

2. 调用cifar10.py 中的 inference 函数,对输入图片进行卷积、池化、本地化等操作,之后获取最终的 logits

3. 加载训练模型时保存的恢复点,并对测试图片进行预测


要实现上诉的操作,常用的有类似MNIST使用placeholder的方式,这种方式可以参看一个不完整的样例 How to classify images using tensorflow cifar10 model

我实现的方式是另一种方式,代码如下:

# -*- coding:utf-8 -*-import tensorflow as tffrom tensorflow.python.ops.image_ops_impl import ResizeMethodfrom prettytable import PrettyTable  # PrettyTable使用参看http://www.ithao123.cn/content-2560565.htmlimport cifar10import numpy as npFLAGS = tf.app.flags.FLAGS# 设置存储模型训练结果的路径tf.app.flags.DEFINE_string('checkpoint_dir', '/home/xzy/cifar10_train_xzy',                           """Directory where to read model checkpoints.""")tf.app.flags.DEFINE_string('class_dir', '/home/xzy/cifar10-input/cifar-10-batches-bin/',                           """存储文件batches.meta.txt的目录""")tf.app.flags.DEFINE_string('test_file', '/home/xzy/dog.jpg', """测试用的图片""")IMAGE_SIZE = 24def evaluate_images(images):  # 执行验证    logits = cifar10.inference(images, batch_size=1)    load_trained_model(logits=logits)def load_trained_model(logits):    with tf.Session() as sess:        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)        if ckpt and ckpt.model_checkpoint_path:            # 从训练模型恢复数据            saver = tf.train.Saver()            saver.restore(sess, ckpt.model_checkpoint_path)        else:            print('No checkpoint file found')            return        # 下面两行是预测最有可能的分类        # predict = tf.argmax(logits, 1)        # output = predict.eval()        # 从文件以字符串方式获取10个类标签,使用制表格分割        cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t')        # 预测最大的三个分类        top_k_pred = tf.nn.top_k(logits, k=3)        output = sess.run(top_k_pred)        probability = np.array(output[0]).flatten()  # 取出概率值,将其展成一维数组        index = np.array(output[1]).flatten()        # 使用表格的方式显示        tabel = PrettyTable(["index", "class", "probability"])        tabel.align["index"] = "l"          tabel.padding_width = 1         for i in np.arange(index.size):            tabel.add_row([index[i], cifar10_class[index[i]], probability[i]])        print tabeldef img_read(filename):    if not tf.gfile.Exists(filename):        tf.logging.fatal('File does not exists %s', filename)    image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(filename),                                                                   channels=3), dtype=tf.float32)    height = IMAGE_SIZE    width = IMAGE_SIZE    image = tf.image.resize_images(image_data, (height, width), method=ResizeMethod.BILINEAR)    image = tf.expand_dims(image, -1)    image = tf.reshape(image, (1, 24, 24, 3))    return imagedef main(argv=None):  # pylint: disable=unused-argument    filename = FLAGS.test_file    images = img_read(filename)    evaluate_images(images)if __name__ == '__main__':    tf.app.run()

上诉的代码,给定一张狗的图片,显示最大的三类识别的结果,最后是用prettytabel打印出的效果

+-------+-------+-------------+| index | class | probability |+-------+-------+-------------+| 3     |  cat  |   0.530751  || 5     |  dog  |   0.491245  || 2     |  bird |   0.139152  |+-------+-------+-------------+
由于训练的样本设置成100,模型的准确率反而cat的较高,这很正常

本代码的原创之处,自己探索好久尝试写出来的,其他地方绝对没有:

1. 对inference函数加入一个参数,让其默认的值是原来的128,验证单张的时候传入1,这样不会影响原来的测试样本集的验证

2. 使用  np.loadtxt 来读取cifar10的10个类,方便后续知道下标,获取cifar10的种类名称

3. tf.nn.top_k(logits, k=3) 来显示最大的3个类别的概率和index下标

4. 使用prettytable格式化输出


完整的代码,见本人github  https://github.com/xzy256/cifar10_xzy   ,欢迎评论

参考文献

将二进制转换成图片                
tensorflow学习之识别单张图片的实现(python手写数字)           
tensorflow实现embedding展示的简单快速构建例子            
tensorflow使用cifar10模型进行验证单张图片的代码             
tensorfloe模块化函数实现cifar10数据集上测试单张图片     
cifar10获取10个类别的方法        


阅读全文
0 0
原创粉丝点击