[TensorFlow]入门学习笔记(3)-图像预处理
来源:互联网 发布:果园拆分系统源码 编辑:程序博客网 时间:2024/06/05 16:46
图像预处理
前言
因为在做目标追踪方面,一直在matlab中写代码,不得不说改代码改的又复杂又难改,优化难做啊。就把图像预处理过程直接放到tensorflow中学习吧。
TFRecord数据格式
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,之前我一直用的这个方法)。
如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords。这种方式也对图像文件比较友好,主要是图像的都比较大。
不多说,直接贴代码,有些写在注释里了。
读取TFRecord中的样例,以队列的形式。
#-*-encoding:UTF-8 -*-import tensorflow as tf#创建一个reader来读取TFRecord中的样例reader = tf.TFRecordReader()#创建一个临时队列用于维护输入文件列表filename_queue = tf.train.string_input_producer(["path/to/output.tfrecords"])#从文件中读取一个样例。read_up_to函数用于一次性读取多个样例_,serialized_example = reader.read(filename_queue)#解析读入的样例。如果需要解析多个样例,用parse_examplefeatures = tf.parse_single_example(serialized_example,features={ #TensorFlow提供了两种属性解析方法 #1.tf.FixedLenFeature解析结果为Tensor. #2.tf.VarLenFeature,这种解析方法解析为一个SparseTensor,用于处理稀疏矩阵 #格式需要与上面的写入数据的格式相一致 'image_raw':tf.FixedLenFeature([],tf.string), 'pixels':tf.FixedLenFeature([],tf.int64), 'label':tf.FixedLenFeature([],tf.int64),})#tf.decode_raw将字符串解析为对应的像素数组images = tf.decode_raw(features['image_raw'],tf.uint8)labels = tf.cast(features['label'],tf.int32)pixels = tf.cast(features['pixels'],tf.int32)sess = tf.Session()#启动多线程处理输入数据coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess,coord=coord)#每次读取一个样例,当所有样例读取完之后,从头读取。for i in range(10): image,label,pixel = sess.run([images,labels,pixels])
那么如何将已经存在的数据存储为TFRecord呢?
以mnist的数据集为例,
- 获取你的数据
- 将数据填入到Example协议内存块(protocol buffer)
- 将数据中的信息feature写入这个结构
- 之后,通过tf.python_io.TFRecordWriter 写入到TFRecords文件
# -*- coding: UTF-8 -*import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport numpy as np#生成整数型的属性def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))#生成字符串的属性def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))mnist = input_data.read_data_sets("../MNIST_data/", dtype=tf.uint8,one_hot=True)images = mnist.train.imageslabels = mnist.train.labelspixels = images.shape[1] #像素值可以作为Example的一个属性num_examples = mnist.train.num_examples#输出TFRecord文件的地址filename = "path/to/output.tfrecords"#创建一个writer来写TFRecord文件writer = tf.python_io.TFRecordWriter(filename)for index in range(num_examples): #将图像转化为字符串 image_raw = images[index].tostring() #将一个样例转化为Example Protocol Buffer,并将所有信息写入这个结构 example = tf.train.Example(features = tf.train.Features(feature={ 'pixels':_int64_feature(pixels), 'label':_int64_feature(np.argmax(labels[index])), 'image_raw':_bytes_feature(image_raw) })) #写入TFRecord writer.write(example.SerializeToString())writer.close()
图像预处理
图像预处理有很多过程。这里只介绍函数。方便使用。
- 图像读取原始
- tf.gfile.FastGFile().read()
- 图像格式的编码解码 :图像不直接记录图像上的不同位置,不同颜色的亮度。而是记录压缩编码之后的结果。所以要还原成三维矩阵,需要解码。
- tf.image.decode_jpeg()
- tf.image.encode_jpeg()
- 转换函数 tf.image.convert_image_dtype
- 图像大小调整
- tf.image.resize_images(image,[size],method)
- method 0:双线性插值 1:最近邻居法 2: 双三次插值法 3:面积差值法
- tf.image.resize_image_with_crop_pad 自动裁剪或者填充
- 图像翻转
- tf.image.flip_up_down()
- tf.image.filp_left_right()
- tf.image.transpose_image()
- 图像色彩调整
- 亮度调整 tf.image.adjust_brightness(image,brightness)
- 随机亮度调整 tf.image.random_brightness(image,max_delta)
- 同理调整,tf.image.adjust_contrast,tf.image.adjust_hue,tf.image.
saturation. - 图像标准化 tf.image.per_image_whitening(image)
标注框
- tf.image.draw_bounding_boxes(batch,boxes) 这个函数要求图像矩阵的数字为实数,而且输入是一个batch的数据,即多张图像组成的四维矩阵,所以将编码后的图像矩阵加一维。
- tf.expand_dims() 这个加的维度大家自己要看api去理解
tf.image.sample_distorted_bounding_box(size,boxes) 随机截取图像信息
随机翻转图像,随机调整颜色,随机截图图像中的有信息含量的部分,这些事提高模型健壮性的一种方式。这样可以使是训练得到的模型不受被识别物体大小的影响。
下面贴出完整代码:
# -*- coding: UTF-8 -*import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltdef distort_color(image,color_ordering=0): if color_ordering == 0: image = tf.image.random_brightness(image,max_delta=32./255.)#亮度 image = tf.image.random_saturation(image,lower=0.5,upper=1.5)#饱和度 image = tf.image.random_hue(image,max_delta=0.2)#色相 image = tf.image.random_contrast(image,lower=0.5,upper=1.5)#对比度 elif color_ordering == 1: image = tf.image.random_brightness(image, max_delta=32. / 255.) # 亮度 image = tf.image.random_hue(image, max_delta=0.2) # 色相 image = tf.image.random_saturation(image, lower=0.5, upper=1.5) # 饱和度 image = tf.image.random_contrast(image, lower=0.5, upper=1.5) # 对比度 return tf.clip_by_value(image,0.0,1.0) #将张量值剪切到指定的最小值和最大值def preprocess_for_train(image,height,width,bbox): #如果没有提供标注框,则认为整个图像就是需要关注的部分 if bbox is None: bbox = tf.constant([0.0,0.0,1.0,1.0],dtype=tf.float32,shape=[1,1,4]) #转换图像张量的类型 if image.dtype != tf.float32: image = tf.image.convert_image_dtype(image,dtype=tf.float32) #随机截取图像,减少需要关注的物体大小对图像识别的影响 bbox_begin,bbox_size,_ = tf.image.sample_distorted_bounding_box(tf.shape(image), bounding_boxes=bbox) distort_image = tf.slice(image,bbox_begin,bbox_size) #将随机截图的图像调整为神经网络输入层的大小。大小调整的算法是随机的 distort_image = tf.image.resize_images( distort_image,[height,width],method=np.random.randint(4) ) #随机左右翻转图像 distort_image = tf.image.random_flip_left_right(distort_image) #使用一种随机的顺序调整图像色彩 distort_image = distort_color(distort_image,np.random.randint(1)) return distort_imageimage_raw_data = tf.gfile.FastGFile("../cat.jpg",'r').read()with tf.Session() as Sess: ima_data = tf.image.decode_jpeg(image_raw_data) boxes = tf.constant([[[0.05,0.05,0.9,0.7],[0.35,0.47,0.5,0.56]]]) #运行6次获得6中不同的图像,在图中显示效果 for i in range(6): #将图像的尺寸调整为299*299 result = preprocess_for_train(ima_data,299,299,boxes) plt.imshow(result.eval()) plt.show()
1 0
- [TensorFlow]入门学习笔记(3)-图像预处理
- TensorFlow学习笔记-图像预处理
- Tensorflow图像预处理函数学习笔记
- TensorFlow学习--tensorflow图像预处理
- TensorFlow学习(十):图像预处理
- TensorFlow图像数据预处理
- tensorflow图像数据预处理
- Tensorflow基础:图像预处理
- Tensorflow-图像预处理
- tensorflow入门学习笔记
- Tensorflow入门学习笔记
- Tensorflow入门学习笔记
- TensorFlow学习笔记:入门
- halcon学习笔记(1) 图像预处理
- Tensorflow中图像的预处理
- TensorFlow学习笔记(二):TensorFlow入门
- TensorFlow学习笔记1:入门
- TensorFlow 学习笔记-入门篇
- JQuery中常用的AJAX方法
- h5day1
- 随机森林特点
- Reverse Integer
- 线性模型
- [TensorFlow]入门学习笔记(3)-图像预处理
- 二维码在线生成
- NOJ_1005
- Graphviz中文处理
- transform
- [leetcode]561. Array Partition I
- 基于glist自定义自己的链表数据结构
- 几种设计模式
- dao层开发代码