tensorflow学习之图像处理

来源:互联网 发布:p2p线上平台软件 编辑:程序博客网 时间:2024/06/03 20:31
# -*- coding: utf-8 -*-import  tensorflow as tfimport  os#cifar_10=input_datafrom six.moves import  xrangeIMAGE_SISE=24  ## 原图像的尺度为32*32,但根据常识,信息部分通常位于图像的中央,这里定义了以中心裁剪后图像的尺寸NUM_CLASSES=10NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN=50000NUM_EXAMPLES_PER_EPOCH_FOR_EVAL=10000#读取数据集,并根据数据集使用说明做数据预处理def read_cifar(filename_queue):    class CIFAR10Record(object):        pass    return CIFAR10Record()    label_bytes=1 #2 for cifar_100    result.height=32  #结果中的行数    result.width=32   #结果中的列数    result.depth=3    #结果中的颜色通道数    image_bytes=result.height*result.width*result.depth    record_bytes=label_bytes+image_bytes    reader=tf.FixedLengthRecordReader(record_bytes)    result.key,value=reader.read(filename_queue) #读取一行记录,从filename_queue队列中获取文件名    record_bytes=tf.decode_raw(value,tf.uint8)  #将长度为record_bytes的字符串转换为uint8的向量    result.label==tf.cast(tf.strided_slice(record_bytes,[0],[label_bytes],tf.int32)) # [0]和[label_bytes]分别表示待截取片段的起点和长度 ,转换int32    depth_major=tf.reshape(tf.strided_slice(record_bytes,[label_bytes],[label_bytes+image_bytes]),                           [result.depth,result.height,result.width])    result.uint8image=tf.transpose(depth_major,[1,2,0])  # 将 [depth, height, width] 转换为[height, width, depth]    return  result#创建一个队列的批量图和标签def _generate_image_and_label_batch(image,label,min_queue_examples,batch_size,shuffle):    # 创建一个混排样本的队列,然后从样本队列中读取 'batch_size'数量的 images + labels数据(每个样本都是由images + labels组成)    num_preprocess_threads=16  #预处理采用多线程    if shuffle:        images,label_batch=tf.train.shuffle_batch(            [image,label],            batch_size=batch_size,            num_threads=num_preprocess_threads,            capacity=min_queue_examples+3*batch_size        )    else:        images,label_batch,=tf.train.batch(            [image,label],            batch_size=batch_size,            num_threads=num_preprocess_threads,            capacity=min_queue_examples        )    tf.summary.image('images',images) #训练图像可视化    return  images,tf.reshape(label_batch,[batch_size])#使用Reader操作构建扭曲的输入(图像)用作CIFAR训练def distorted_inputs(data_dir,batch_size):    filenames=[os.path.join(data_dir,'data_batch_%d.bin'%i)               for i in  xrange(1,6)]    for f in  filenames:        if not tf.gfile.Exists(f):            raise ValueError('Faile to fine file:'+f)    filename_queue=tf.train.string_input_producer(filenames)  #创建文件名队列    read_input=read_cifar(filename_queue)                     #从文件名队列中读取样本    reshaped_image=tf.cast(read_input.uint8image,tf,float32)    height=IMAGE_SIZE    #用于训练神经网络的图像处理,    width=IMAGE_SIZE     #对图像进行了很多随机扭曲处理    distorted_image=tf.random_crop(reshaped_image,[height,width,3])  ##随机修建图像的某一块[height,width]区域    distorted_image=tf.image.random_flip_left_right(distorted_image)  #随机水平翻转图像    distorted_image=tf.image.random_brightness(distorted_image,max_delta=63) #随机变换图像的亮度    distorted_image=tf.image.random_contrast(distorted_image,lower=0.2,upper=1.8) #随机变换图像的对比度    float_image=tf.image.per_image_standardization(distorted_image) #对图像进行标准化:减去均值并除以像素的方差    #设置张量的形状    float_image.set_shape([height,width,3])    read_input.label.set_shape([1])    #确保随机混排有很好的混合性    min_fraction_of_examples_queue=0.4    min_queue_examples=int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_queue)    print('Filling queue with %d CIFAR images before starting to train.'          'This will take a few minutes.'%min_queue_examples)    #通过构建一个样本队列来生成一批量的图像和标签    return  _generate_image_and_label_batch(float_image,read_input.label,min_queue_examples,batch_size,shuffle=True)#: 使用Reader ops操作构建CIFAR评估的输入def inputs(eval_data,data_dir,batch_size):    if not eval_data:        filenames=[os.path.join(data_dir,'data_batch_%d.bin'% i)                   for i in  xrange(1,6)]        num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN    else:        filenames=[os.path.join(data_dir,'test_batch.bin')]        num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_EVAL    for f in  filenames:        if not tf.gfile.Exists(f):            raise ValueError('Failed to find file:'+f)    filename_queue=tf.train.string_input_producer(filenames)#创建一个文件名队列    read_input=read_cifar(filename_queue)    reshaped_image=tf.cast(read_input.uint8image,tf.float32)    height=IMAGE_SISE    width=IMAGE_SISE    resized_image=tf.image.resize_image_with_crop_or_pad(reshaped_image,                                                         height,width)#裁剪图像的中心    float_image=tf.image.per_image_standardization(resized_image)     #标准化:减去均值并除以像素的方差    float_image.set_shape([height,width,3])  #设置张量的形状    read_input.label.set_shape([1])    min_fraction_of_example_in_queue=0.4    min_queue_examples=int(num_examples_per_epoch * min_fraction_of_example_in_queue)    return _generate_image_and_label_batch(float_image,read_input.label,min_queue_examples,batch_size,shuffle=False)

cifar10_input 图像处理