TensorFlow学习(一)
来源:互联网 发布:下电子书的软件 编辑:程序博客网 时间:2024/04/30 09:17
使用TensorFlow的TFRecord来保存和读取数据
在网上看别人写的程序,总觉得已经明白了,没想到自己来写,发现有好多的坑:
- TFRecordWrite/Read不支持中文路径(至少我这里)
- 一定要注意tf.train.Example的层次结构
- Features
- feature1
- {key=string, value=tf.train.Feature()}
- {key, value}
- feature2
- feature1
- Features
- 我使用parse_single_example来解析TFRecordReader后的记录
- features一定要与TFRecordWriter对应
- 在使用string_input_producer来读记录的时候,如果设置了num_epochs的值,一定还要使用local_variables_initializer初始化一次(仅仅用global_variables_initializer会报错)
代码如下
本代码目的,CSV文件中包含JPG文件名和类型,需要按照一定比例分别创建Train, Validation和Test数据集,集合文件为TFRecordFormat,然后用TFRecordReader方式读出来:
“`python
import ioimport sysimport osimport csvimport randomfrom PIL import Imge#filepath = 'D:/文件/数据资料/图像数据/照片/正面分割缩小'#TFRecord不支持中文路径filepath = 'D:\\temp'csvfilename='BL2-CO3.csv'#三种集合的文件名trainfile = 'train.tfrecord'valfile = 'validation.tfrecord'testfile = 'test.tfrecord'imgwidth = 1024imgheight = 384#创建数据集合def CreateData(): #设置各集合的比例 trainratio = 0.7 valratio = 0.15 testratio = 0.15 #读取CSV文件,第一列是文件名,第二列是类型 fileList= [] with open(os.path.join(filepath, csvfilename), newline='') as csvfile: spamreader = csv.reader(csvfile, delimiter=',') isheader = True for row in spamreader: if isheader==False: fileList.append([row[0], int(row[1])]) isheader = False; #打乱顺序 filecount = len(fileList) random.shuffle(fileList) #创建训练、验证和测试集合 index1 = round(filecount * trainratio) trainlist = fileList[0:index1] index2 = round(filecount*valratio) + index1; vallist = fileList[index1:index2] testlist = fileList[index2:] #保存列表到tfrecord文件 def SaveDataSet(datalist, savepath): writer = tf.python_io.TFRecordWriter(savepath) for item in datalist: img = Image.open(item[0]) img.load() #定义label和image两个属性 #使用flatten()将多维图像数据扁平化[height, width, channel] -->[byte] #使用tostring()转换为byte array record = tf.train.Example( features = tf.train.Features( feature = { 'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[item[1]])), 'image':tf.train.Feature(bytes_list= tf.train.BytesList(value=[np.asarray(img, dtype=np.uint8).flatten().tostring()])) } ) ) writer.write(record.SerializeToString()) writer.close() SaveDataSet(trainlist, os.path.join(filepath, trainfile)) SaveDataSet(vallist, os.path.join(filepath, valfile)) SaveDataSet(testlist, os.path.join(filepath, testfile))#读取数据集def TestData(filename, epochs): #读取一条记录 (在线程里面调用) def ReadOneRecord(filename, epochs): #创建文件名队列,这里只有一个文件 filename_queue = tf.train.string_input_producer([filename], num_epochs=epochs) #开始读取记录 reader = tf.TFRecordReader() _,record = reader.read(filename_queue) #解析记录 features = tf.parse_single_example( record, features={ 'label':tf.FixedLenFeature([], tf.int64), 'image':tf.FixedLenFeature([],tf.string) } ) #解析出来的内容 label = features['label'] image = tf.decode_raw(features['image'], tf.uint8) image = tf.reshape(image, [imgheight, imgwidth, 3]) return label, image #获取所有集合数据 def Feech_Data(sess, coord, threads, label_batch, image_batch): try: while not coord.should_stop(): #如果没有取完数据(num_epochs没有结束) label, image = sess.run([label_batch, imag e_batch]) #取数据 print('labelsize=%d, imagesize=%d'%(label.size, image.size)) except tf.errors.OutOfRangeError: #表示num_epoch结束了 print('done training') finally: coord.request_stop() #等待所有线程结束 coord.join(threads) ###获取数据## #引用记录解析过程 label, image = ReadOneRecord(filename, epochs) #TF批次取数据,capacity和min_after_dequeue参数设置比较重要 #网上有不少讲如何设置的,我只是实验,没有讲究 label_batch, image_batch = tf.train.shuffle_batch([label, image], batch_size=5, capacity=30, min_after_dequeue=10, num_threads=2) #获取shuffle_batch创建的线程 #queues = tf.get_collection(tf.GraphKeys.QUEUE_RUS) sess = tf.Session() init = tf.global_variables_initializer() sess.run(init) #一定要运行一下这个初始化,否则num_epochs报错 init = tf.local_variables_initializer() sess.run(init) #控制数据获取线程是否结束 coord = tf.train.Coordinator() #启动数据获取线程 threads = tf.train.start_queue_runners(sess=sess, coord=coord) Feech_Data(sess=sess, coord=coord, threads=threads, label_batch=label_batch, image_batch=image_batch) sess.close()###### 1=创建数据, 0=获取数据######if 0 : CreateData()else: TestData(os.path.join(filepath, trainfile), 50)
阅读全文
0 0
- TensorFlow学习(一)
- 【tensorflow学习】(一)
- TensorFlow学习笔记(一):TensorFlow安装
- TensorFlow学习系列(一):初识TensorFlow
- tensorflow学习笔记(一):tensorflow安装
- TensorFlow学习笔记(一)
- TensorFlow学习(一):感受一下
- tensorflow学习笔记(一)
- TensorFlow学习(一)入门
- TensorFlow学习笔记(一)
- tensorflow学习math_ops(一)
- TensorFlow学习(一):感受一下
- TensorFlow学习(一):感受一下
- TensorFlow学习笔记(一)
- tensorflow学习系列(一)
- Tensorflow学习记录(一)
- Tensorflow学习笔记(一)
- tensorflow学习笔记(一)
- Python装饰器详解
- python---面向对象编程1
- samba服务器的配置
- 离散题目12(判断是否为函数 c++处理)
- JavaScript介绍
- TensorFlow学习(一)
- c# 窗体应用程序 如何添加图片
- Android--onKeyDown方法
- 测试
- C++ 04 —— 构造函数
- nfs服务器的配置
- Storm(二):集群部署配置
- Redis从入门到熟练掌控
- linux上操作mysql数据库