Tensorflow图片数据读取
来源:互联网 发布:离线地图 知乎 编辑:程序博客网 时间:2024/05/29 17:43
关于Tensorflow读取数据,有三种方法:
- 供给数据(feeding):在Tensorflow程序运行的每一步,让Python代码来供给数据。
- 从文件中读取数据:在Tensorflow图的开始,让一个输入管线从文件中读取数据。
- 预加载数据:在Tensorflow图中定义常量或者变量来保存所有的数据(仅适用于数据量比较小的情况)
对于数据量较小而言,一般选择直接将数据加载进内存,然后再分batch输入网络进行训练;但是,如果数据量较大,最好的方法就是使用Tensorflow提供的队列queue,也就是第二种方法从文件中读取数据,对于一些特定格式的读取,比如cs文件格式。另外还有一种比较官方的通用Tensorflow内定的标准格式—-TFRecords。
从csv文件中读取数据
csv文件一般一行是一个样本,包括样本值和label,用逗号隔开,读取的时候可以用tf.TextLlineReader类来每次读取一行,并使用tf.decode_csv对每一行解析。
tf.decode_csv(records,record_dafaults,field_delim=None,name=None)
- records为reader独到的内容,此处为csv文件的一行。
- record_defaults为指定分割后的每个参数的类型,比如分割后有三列,那么第二个参数就应该是[[‘int32’],[],[‘tring’]],不指定类型(设为空也可以),如果分割后的属性比较多,可以用[[]*100]来表示。
field_delim是指定用什么来分割,默认为逗号。
col=tf.decode_csv(records, record_defaults=[[]*100],field_delim=‘ ’, name=None)
返回的col是长度为100的list。
需要注意的是,当数据量比较大的时候,存成CSV或TXT文件要比BIN文件大的多,因此在TF中读取的速度也会慢很多。因此尽量不要读取大的CSV的方式来输入。
此处选择鸢尾花进行实例化
官方采用tensorflow内置函数读取数据 的代码
#coding:utf-8from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport tensorflow as tfimport numpy as np# 设定数据集的位置IRIS_TRAINING = "iris_training.csv"IRIS_TEST = "iris_test.csv"# 使用Tensorflow内置的方法进行数据加载training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)# 每行数据4个特征,都是real-value的feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]# 构建一个DNN分类器,3层,其中每个隐含层的节点数量分别为10,20,10,目标的分类3个,并且指定了保存位置classifier =tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model")# 指定数据,以及训练的步数classifier.fit(x=training_set.data, y=training_set.target, steps=2000)# 模型评估accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]print('Accuracy: {0:f}'.format(accuracy_score))# 直接创建数据来进行预测new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)y = list(classifier.predict(new_samples))print('Predictions: {}'.format(str(y)))
自己定义的读取数据的形式的代
import tensorflow as tfimport osdef read_data(file_queue): reader = tf.TextLineReader(skip_header_lines=1)#从第二行开始读 key, value = reader.read(file_queue) #print(key,values) defaults = [[0], [0.], [0.], [0.], [0.], ['']] Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species = tf.decode_csv(value, defaults) #数据处理 preprocess_op = tf.case({ tf.equal(Species, tf.constant('Iris-setosa')): lambda: tf.constant(0), tf.equal(Species, tf.constant('Iris-versicolor')): lambda: tf.constant(1), tf.equal(Species, tf.constant('Iris-virginica')): lambda: tf.constant(2), }, lambda: tf.constant(-1), exclusive=True) #栈 return tf.stack([SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm]), preprocess_opdef create_pipeline(filename, batch_size, num_epochs=None): file_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs) example, label = read_data(file_queue) #tf.train.string_input_producer:创建一个线程 min_after_dequeue = 1000 capacity = min_after_dequeue + batch_size example_batch, label_batch = tf.train.shuffle_batch( [example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue )#tf.train.shuffle_batch:对队列中的样本进行乱序处理,返回的是从队列中随机筛选多个样例返回给image_batch,label_batch, return example_batch, label_batchx_train_batch, y_train_batch = create_pipeline('Iris-train.csv', 50, num_epochs=1000)x_test, y_test = create_pipeline('Iris-test.csv', 60)global_step = tf.Variable(0, trainable=False)learning_rate = 0.1#tf.train.exponential_decay(0.1, global_step, 100, 0.0)# Input layerx = tf.placeholder(tf.float32, [None, 4])y = tf.placeholder(tf.int32, [None])# Output layerw = tf.Variable(tf.random_normal([4, 3]))b = tf.Variable(tf.random_normal([3]))a = tf.matmul(x, w) + bprediction = tf.nn.softmax(a)# Trainingcross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=a, labels=y))tf.summary.scalar('Cross_Entropy', cross_entropy)train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)correct_prediction = tf.equal(tf.argmax(prediction,1), tf.cast(y, tf.int64))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))tf.summary.scalar('Accuracy', accuracy)init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())merged_summary = tf.summary.merge_all()sess = tf.Session()train_writer = tf.summary.FileWriter('logs/train', sess.graph)test_writer = tf.summary.FileWriter('logs/test', sess.graph)sess.run(init)coord = tf.train.Coordinator()#线程的协调器threads = tf.train.start_queue_runners(sess=sess, coord=coord)#启动输入管道的线程,填充样本到队列中以便出队操作可以从队列中拿到样本。try: print("Training: ") count = 0 #curr_x_test_batch, curr_y_test_batch = sess.run([x_test, y_test]) while not coord.should_stop(): # Run training steps or whatever curr_x_train_batch, curr_y_train_batch = sess.run([x_train_batch, y_train_batch]) sess.run(train_step, feed_dict={ x: curr_x_train_batch, y: curr_y_train_batch }) count += 1 ce, summary = sess.run([cross_entropy, merged_summary], feed_dict={ x: curr_x_train_batch, y: curr_y_train_batch }) train_writer.add_summary(summary, count) curr_x_test_batch, curr_y_test_batch = sess.run([x_test, y_test]) ce, test_acc, test_summary = sess.run([cross_entropy, accuracy, merged_summary], feed_dict={ x: curr_x_test_batch, y: curr_y_test_batch }) test_writer.add_summary(summary, count) print('Batch', count, 'J = ', ce,'测试准确率=',test_acc)except tf.errors.OutOfRangeError: print('Done training -- epoch limit reached')finally: # When done, ask the threads to stop. coord.request_stop() #通知其他线程关闭# Wait for threads to finish.coord.join(threads)#其他所有线程关闭后,此函数返回sess.close()
TFRcords文件
我们使用tf.train.Example来定义我们要填入的数据格式,然后使用tf.Python_io.TFRecordWriter来写入 。
一旦生成了TFRecords文件,接下来就可以使用队列(queue)读取数据了。
采用猫狗的图片作为例子进实例化,如何读取图片的数据。
import tensorflow as tfimport osfrom PIL import Imagecwd=os.getcwd()#print(cwd)classes=['cats','dogs']writer=tf.python_io.TFRecordWriter('train.tfrecords')for index,name in enumerate(classes): class_path= cwd+ '\\'+ name +'\\' for img_name in os.listdir(class_path): img_path=class_path+img_name img =Image.open(img_path) img=img.resize((224,224)) img_raw=img.tobytes() example=tf.train.Example(features=tf.train.Features(feature={ 'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) }))writer.write(example.SerializeToString())writer.close()for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"): example = tf.train.Example() example.ParseFromString(serialized_example) image = example.features.feature['img_raw'].bytes_list.value label = example.features.feature['label'].int64_list.value print(image) print(label)def read_and_decode(filename): #根据文件名生成一个队列 filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.int32) return img, labelimg, label = read_and_decode("train.tfrecords")#使用shuffle_batch可以随机打乱输入img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=30, capacity=2000,min_after_dequeue=1000)init = tf.global_variables_initializer()with tf.Session() as sess: sess.run(init) threads = tf.train.start_queue_runners(sess=sess) for i in range(3): val, l= sess.run([img_batch, label_batch]) #我们也可以根据需要对val, l进行处理 #l = to_categorical(l, 12) print(val.shape, l)
总结:
- 生成TFRecord文件
- 定义record reader 解析tfrecord文件
- 构造一个批生成器
- 构建其他的操作
- 初始化所有的操作
- 启动QueueRunner
阅读全文
0 0
- Tensorflow图片数据读取
- Tensorflow图片数据读取
- tensorflow图片数据读取
- tensorflow 读取图片
- tensorflow io 图片读取
- tensorflow爬坑行:数据读取
- Tensorflow读取数据
- Tensorflow读取数据
- tensorflow读取文件数据
- Tensorflow读取数据1
- Tensorflow数据读取方式
- Tensorflow数据读取方法
- TensorFlow读取tfrecords数据
- tensorflow读取数据
- TensorFlow数据读取
- tensorflow读取数据
- tensorflow 数据读取笔记
- TensorFlow数据读取
- rsync 批处理,忽略文件
- AndroidUtils:Android开发不得不收藏的Utils
- SDUVJ开发实录(五):Problem等界面的显示优化
- 全卷积网络(FCN)与图像分割
- python对json的操作
- Tensorflow图片数据读取
- JVM初窥:Java对象的内存结构
- 3.Javascript语法语句
- VIM和Python编码转换原理图
- python中的网络编程
- 软件工程学习(2)
- intellijidea设置"向前"和"向后"快捷键
- MySQL基础教程2-创建表和列操作
- MyEclipse 清理项目缓存的几大方法