TensorFLow能够识别的图像文件,可以通过numpy
来源:互联网 发布:js数组字符串方法 编辑:程序博客网 时间:2024/05/29 11:09
TensorFlow能够识别的图像文件,可以先用numpy读入,再通过tf.Variable或者tf.placeholder加载进TensorFlow;也可以通过TensorFlow自带的函数(如tf.read_file)读取。当图像文件过多时,一般使用pipeline,通过队列的方式进行读取。下面我们介绍两种生成TensorFlow图像数据的方法,用于供给TensorFlow计算图(graph)的输入与输出。
1 通过numpy读取
import cv2
import numpy as np
import h5py

# Spatial size the images are expected to have once they are in
# (image_num, height, width, 3) layout.
height = 460
width = 345

# h5py can only open the file if it is an HDF5 container
# (presumably a MATLAB v7.3 .mat file -- confirm).
with h5py.File('make3d_dataset_f460.mat', 'r') as f:
    # [:] copies the whole 'images' dataset into memory as a numpy array,
    # so it stays usable after the file is closed.
    images = f['images'][:]

image_num = len(images)

# Keep axis 0 (the image index) and reverse the three trailing axes.
# Presumably this turns (N, 3, width, height) into the intended
# (image_num, height, width, 3) layout -- TODO confirm against the file.
# No preallocated buffer is needed: transpose returns the result array (a view).
data = images.transpose((0, 3, 2, 1))
先在图像所在目录下生成图像文件名列表:ls *.jpg > list.txt
import os

import cv2
import numpy as np

# Directory that holds both the list file and the images it names.
image_path = './'
list_file = 'list.txt'
# Size every image is resized to before being stored in `data`.
height = 48
width = 48

# One image file name per line; blank lines are skipped so that they do not
# turn into empty file names.
with open(os.path.join(image_path, list_file)) as fid:
    image_name_list = [line.strip() for line in fid if line.strip()]
image_num = len(image_name_list)

# Output batch in (image_num, height, width, 3) layout, 8-bit per channel.
data = np.zeros((image_num, height, width, 3), np.uint8)

for idx, image_name in enumerate(image_name_list):
    # The names in the list are relative to image_path, like the list itself.
    full_name = os.path.join(image_path, image_name)
    img = cv2.imread(full_name)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('cannot read image: ' + full_name)
    # cv2.resize expects dsize as (width, height); the resulting array is
    # (height, width, 3), which matches one slot of `data`.
    data[idx] = cv2.resize(img, (width, height))
2 Tensorflow自带函数读取
def get_image(image_path):
    """Read a JPEG file and return it as a 3-channel uint8 image tensor.

    Requires ``import tensorflow as tf`` in the enclosing module.

    Args:
        image_path: Path of the JPEG file to read, as accepted by
            ``tf.read_file``.

    Returns:
        The decoded image as a tensor with 3 channels and dtype ``tf.uint8``.
    """
    raw_bytes = tf.read_file(image_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    # The conversion pins the output dtype explicitly; decode_jpeg is
    # documented to return uint8 already, so no rescaling is expected here.
    return tf.image.convert_image_dtype(decoded, dtype=tf.uint8)
pipeline读取方法
-
- import tensorflow as tf
- import random
- from tensorflow.python.framework import ops
- from tensorflow.python.framework import dtypes
-
# Dataset root; it is prepended by plain string concatenation to the label
# file names and to every image path, so it must end with a path separator.
dataset_path = "/path/to/your/dataset/mnist/"
test_labels_file = "test-labels.csv"
train_labels_file = "train-labels.csv"

# Number of samples that are assigned to the test partition.
test_set_size = 5

# Static shape that is set on every decoded image before batching.
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28
# Passed to decode_jpeg as the number of colour channels to produce.
NUM_CHANNELS = 3
# Number of samples returned by each run of a batch op.
BATCH_SIZE = 5
-
def encode_label(label):
    """Convert a label as read from a label file into an integer class id.

    Args:
        label: The label text. Surrounding whitespace, including a trailing
            newline, is accepted because ``int()`` ignores it.

    Returns:
        The label as an ``int``.

    Raises:
        ValueError: If ``label`` is not the text of an integer.
    """
    return int(label)
-
def read_label_file(file):
    """Parse a label file with one ``<filepath>,<label>`` entry per line.

    Args:
        file: Path of the label file to read.

    Returns:
        A ``(filepaths, labels)`` tuple of two lists in file order:
        the path column as strings and the labels passed through
        ``encode_label``.

    Raises:
        ValueError: If a non-blank line contains no comma or its label
            cannot be encoded.
    """
    filepaths = []
    labels = []
    # The context manager closes the file even if a line fails to parse.
    with open(file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            # Split on the last comma so a path that itself contains commas
            # is kept intact.
            filepath, label = line.rsplit(",", 1)
            filepaths.append(filepath.strip())
            labels.append(encode_label(label))
    return filepaths, labels
-
-
# Read the (relative path, label) pairs of both label files.
train_filepaths, train_labels = read_label_file(dataset_path + train_labels_file)
test_filepaths, test_labels = read_label_file(dataset_path + test_labels_file)

# Turn the relative paths into full paths.
train_filepaths = [dataset_path + fp for fp in train_filepaths]
test_filepaths = [dataset_path + fp for fp in test_filepaths]

# Merge both sets; the train/test split is redone randomly below.
all_filepaths = train_filepaths + test_filepaths
all_labels = train_labels + test_labels

# Keep only the first 20 samples.
all_filepaths = all_filepaths[:20]
all_labels = all_labels[:20]

# Convert the Python lists into constant tensors.
all_images = ops.convert_to_tensor(all_filepaths, dtype=dtypes.string)
all_labels = ops.convert_to_tensor(all_labels, dtype=dtypes.int32)

# Partition vector: 1 marks a test sample, 0 a train sample. The number of
# test samples is clamped so the vector never becomes longer than the data.
num_samples = len(all_filepaths)
num_test = min(test_set_size, num_samples)
partitions = [0] * num_samples
partitions[:num_test] = [1] * num_test
random.shuffle(partitions)

# Partition 0 -> train, partition 1 -> test; the same vector is used for
# paths and labels so that they stay aligned.
train_images, test_images = tf.dynamic_partition(all_images, partitions, 2)
train_labels, test_labels = tf.dynamic_partition(all_labels, partitions, 2)

# Queues that emit one (path, label) pair at a time, in order.
train_input_queue = tf.train.slice_input_producer(
    [train_images, train_labels],
    shuffle=False)
test_input_queue = tf.train.slice_input_producer(
    [test_images, test_labels],
    shuffle=False)

# Read and decode the image file named by each queue element.
file_content = tf.read_file(train_input_queue[0])
train_image = tf.image.decode_jpeg(file_content, channels=NUM_CHANNELS)
train_label = train_input_queue[1]

file_content = tf.read_file(test_input_queue[0])
test_image = tf.image.decode_jpeg(file_content, channels=NUM_CHANNELS)
test_label = test_input_queue[1]

# Declare the static image shape before the tensors are handed to
# tf.train.batch. This only annotates the shape; it does not resize.
train_image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS])
test_image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNELS])

# Group single samples into batches of BATCH_SIZE.
train_image_batch, train_label_batch = tf.train.batch(
    [train_image, train_label],
    batch_size=BATCH_SIZE
)
test_image_batch, test_label_batch = tf.train.batch(
    [test_image, test_label],
    batch_size=BATCH_SIZE
)

# print(...) with a single argument behaves the same on Python 2 and 3.
print("input pipeline ready")

with tf.Session() as sess:
    # NOTE(review): later TensorFlow 1.x releases deprecate this call in
    # favour of tf.global_variables_initializer() -- switch if available.
    sess.run(tf.initialize_all_variables())

    # Start the threads that fill the input queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        print("from the train set:")
        for i in range(20):
            print(sess.run(train_label_batch))

        print("from the test set:")
        for i in range(10):
            print(sess.run(test_label_batch))
    finally:
        # Always stop and join the queue threads, even if a run failed.
        # The enclosing `with` closes the session itself.
        coord.request_stop()
        coord.join(threads)
参考资料
0 0