TensorFlow学习(一)

来源:互联网 发布:下电子书的软件 编辑:程序博客网 时间:2024/04/30 09:17

使用TensorFlow的TFRecord来保存和读取数据

在网上看别人写的程序,总觉得已经明白了,没想到自己来写,发现有好多的坑:

  • TFRecordWrite/Read不支持中文路径(至少我这里)
  • 一定要注意tf.train.Example的层次结构
    • Features
      • feature1
        • {key=string, value=tf.train.Feature()}
        • {key, value}
      • feature2
  • 我使用parse_single_example来解析TFRecordReader后的记录
    • features一定要与TFRecordWriter对应
  • 在使用string_input_producer来读记录的时候,如果设置了num_epochs的值,一定还要使用local_variables_initializer初始化一次(仅仅用global_variables_initializer会报错)

代码如下

本代码目的,CSV文件中包含JPG文件名和类型,需要按照一定比例分别创建Train, Validation和Test数据集,集合文件为TFRecordFormat,然后用TFRecordReader方式读出来:

“`python

import ioimport sysimport osimport csvimport randomfrom PIL import Imge#filepath = 'D:/文件/数据资料/图像数据/照片/正面分割缩小'#TFRecord不支持中文路径filepath = 'D:\\temp'csvfilename='BL2-CO3.csv'#三种集合的文件名trainfile = 'train.tfrecord'valfile = 'validation.tfrecord'testfile = 'test.tfrecord'imgwidth = 1024imgheight = 384#创建数据集合def CreateData():    #设置各集合的比例    trainratio = 0.7    valratio = 0.15    testratio = 0.15    #读取CSV文件,第一列是文件名,第二列是类型    fileList= []    with open(os.path.join(filepath, csvfilename), newline='') as csvfile:        spamreader = csv.reader(csvfile, delimiter=',')        isheader = True        for row in spamreader:            if isheader==False:                fileList.append([row[0], int(row[1])])            isheader = False;    #打乱顺序    filecount = len(fileList)    random.shuffle(fileList)    #创建训练、验证和测试集合    index1 = round(filecount * trainratio)    trainlist = fileList[0:index1]    index2 = round(filecount*valratio) + index1;    vallist = fileList[index1:index2]    testlist = fileList[index2:]            #保存列表到tfrecord文件    def SaveDataSet(datalist, savepath):        writer = tf.python_io.TFRecordWriter(savepath)        for item in datalist:           img = Image.open(item[0])           img.load()           #定义label和image两个属性           #使用flatten()将多维图像数据扁平化[height, width, channel] -->[byte]           #使用tostring()转换为byte array           record = tf.train.Example(               features = tf.train.Features(                    feature = {                        'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[item[1]])),                        'image':tf.train.Feature(bytes_list= tf.train.BytesList(value=[np.asarray(img, dtype=np.uint8).flatten().tostring()]))                        }                    )                )            writer.write(record.SerializeToString())        writer.close()    SaveDataSet(trainlist, os.path.join(filepath, trainfile))    SaveDataSet(vallist, os.path.join(filepath, valfile))    SaveDataSet(testlist, os.path.join(filepath, testfile))#读取数据集def TestData(filename, epochs):    #读取一条记录 (在线程里面调用)    def ReadOneRecord(filename, epochs):        #创建文件名队列,这里只有一个文件        filename_queue = tf.train.string_input_producer([filename], num_epochs=epochs)          #开始读取记录        reader = tf.TFRecordReader()        _,record = reader.read(filename_queue)    #解析记录        features = tf.parse_single_example(            record,            features={                'label':tf.FixedLenFeature([], tf.int64),                'image':tf.FixedLenFeature([],tf.string)                }            )        #解析出来的内容        label = features['label']        image = tf.decode_raw(features['image'], tf.uint8)        image = tf.reshape(image, [imgheight, imgwidth, 3])        return label, image    #获取所有集合数据    def Feech_Data(sess, coord, threads, label_batch, image_batch):        try:            while not coord.should_stop():                #如果没有取完数据(num_epochs没有结束)                label, image = sess.run([label_batch, imag  e_batch])                #取数据                print('labelsize=%d, imagesize=%d'%(label.size, image.size))        except tf.errors.OutOfRangeError:  #表示num_epoch结束了            print('done training')        finally:            coord.request_stop()        #等待所有线程结束        coord.join(threads)    ###获取数据##    #引用记录解析过程    label, image = ReadOneRecord(filename, epochs)    #TF批次取数据,capacity和min_after_dequeue参数设置比较重要    #网上有不少讲如何设置的,我只是实验,没有讲究    label_batch, image_batch = tf.train.shuffle_batch([label, image], batch_size=5, capacity=30, min_after_dequeue=10, num_threads=2)    #获取shuffle_batch创建的线程    #queues = tf.get_collection(tf.GraphKeys.QUEUE_RUS)    sess = tf.Session()    init = tf.global_variables_initializer()    sess.run(init)    #一定要运行一下这个初始化,否则num_epochs报错    init = tf.local_variables_initializer()    sess.run(init)    #控制数据获取线程是否结束           coord = tf.train.Coordinator()    #启动数据获取线程    threads = tf.train.start_queue_runners(sess=sess, coord=coord)    Feech_Data(sess=sess, coord=coord, threads=threads, label_batch=label_batch, image_batch=image_batch)    sess.close()###### 1=创建数据, 0=获取数据######if 0 :    CreateData()else:    TestData(os.path.join(filepath, trainfile), 50)