kaggle:code 猫狗识别 图像识别TensorFlow图像预处理

来源:互联网 发布:京都旅游攻略 知乎 编辑:程序博客网 时间:2024/05/21 15:02
# coding: utf8
# Exploratory preprocessing script for a cat/dog image-recognition data set.
#
# Steps performed at module level, in order:
#   1. list the first 2000 training and 1000 test files,
#   2. decode them with TensorFlow and plot width/height/aspect-ratio stats,
#   3. decode them again, cropping/padding every image to HEIGHT x WIDTH,
#   4. define a helper for pickling data in batches and build one-hot labels.
#
# NOTE(review): uses the TensorFlow 1.x graph/session API (tf.placeholder,
# tf.Session, tf.read_file); it will not run unchanged on TensorFlow 2.x.
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

TRAIN_DIR = './wtrain/'
TEST_DIR = './wtest/'

# Keep only the first 2000 training files and the first 1000 test files.
# os.listdir returns entries in arbitrary order, so "first" means whatever
# order the file system reports.
train_image_file_names = [
    os.path.join(TRAIN_DIR, name) for name in os.listdir(TRAIN_DIR)
][:2000]
test_image_file_names = [
    os.path.join(TEST_DIR, name) for name in os.listdir(TEST_DIR)
][:1000]


def decode_image(image_file_names, resize_func=None):
    """Decode JPEG files into evaluated image arrays.

    A single TensorFlow graph is built once and then run for every file by
    feeding the file name through a placeholder.

    Args:
        image_file_names: Sequence of paths to JPEG files.
        resize_func: Optional callable that takes the decoded image tensor
            and returns the tensor to evaluate instead. It is called once,
            while the graph is being built. ``None`` keeps the decoded
            image as is.

    Returns:
        A list holding one ``session.run`` result per input path, in the
        same order as ``image_file_names``.
    """
    # Graph: file name -> raw bytes -> decoded JPEG -> optional transform.
    graph = tf.Graph()
    with graph.as_default():
        file_name = tf.placeholder(dtype=tf.string)
        contents = tf.read_file(file_name)
        image = tf.image.decode_jpeg(contents)
        if resize_func is not None:
            image = resize_func(image)

    # Run the graph once per file and collect the results. The session is
    # closed by the ``with`` block, so no explicit close() is needed.
    images = []
    with tf.Session(graph=graph) as session:
        # The decode graph itself defines no variables; the initializer is
        # kept so a resize_func that does create variables keeps working.
        tf.global_variables_initializer().run()
        for count, path in enumerate(image_file_names, start=1):
            images.append(session.run(image, feed_dict={file_name: path}))
            if count % 1000 == 0:
                print('Images proccessed:', count)
    return images


# --- Size statistics on the un-resized images -----------------------------
train_images = decode_image(train_image_file_names)
test_images = decode_image(test_image_file_names)
all_images = train_images + test_images

width = []
height = []
aspect_ratio = []
for image in all_images:
    # Each decoded image unpacks as (height, width, channels).
    h, w, _ = np.shape(image)
    aspect_ratio.append(float(w) / float(h))
    width.append(w)
    height.append(h)

print("Mean aspect ratio:", np.mean(aspect_ratio))
plt.plot(aspect_ratio)
plt.show()

print('Mean width : ', np.mean(width))
print('Mean height : ', np.mean(height))
plt.plot(width, height, '.r')
plt.show()

print("Images widther than 500 pixel: ", np.sum(np.array(width) > 500))
print("Images higher than 500 pixel: ", np.sum(np.array(height) > 500))

# Drop the raw images before decoding everything a second time.
del train_images
del test_images
del all_images

# --- Resize every image to a common HEIGHT x WIDTH -------------------------
WIDTH = 500
HEIGHT = 500


def resize_func(image):
    """Crop or pad ``image`` to HEIGHT x WIDTH."""
    return tf.image.resize_image_with_crop_or_pad(image, HEIGHT, WIDTH)


processed_train_images = decode_image(train_image_file_names, resize_func=resize_func)
processed_test_images = decode_image(test_image_file_names, resize_func=resize_func)
print(np.shape(processed_train_images))
print(np.shape(processed_test_images))

# Show up to the first ten resized training images as a visual sanity check.
for preview in processed_train_images[:10]:
    plt.imshow(preview)
    plt.show()


def create_batch(data, label, batch_size):
    """Pickle ``data`` in consecutive slices of at most ``batch_size`` items.

    Slice number ``i`` is written to ``<label>_<i>.pickle``. No file is
    written for an empty trailing slice, so data whose length is an exact
    multiple of ``batch_size`` yields exactly ``len(data) / batch_size``
    files, and empty data yields none.

    Args:
        data: Sliceable sequence to split.
        label: File-name prefix (may include a directory part).
        batch_size: Maximum number of entries per file; must be positive.

    Raises:
        ValueError: If ``batch_size`` is not positive (the loop could never
            make progress).
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    for part, start in enumerate(range(0, len(data), batch_size)):
        content = data[start:start + batch_size]
        with open(label + '_' + str(part) + '.pickle', 'wb') as handle:
            pickle.dump(content, handle)
        print('Saved', label, 'part #' + str(part), 'with', len(content), 'entries.')


# One-hot labels: [1, 0] when the file name contains 'dog', else [0, 1].
# Only the base name is inspected so a directory called e.g. 'dogs' cannot
# flip every label.
labels = [
    [1.0, 0.0] if 'dog' in os.path.basename(name) else [0.0, 1.0]
    for name in train_image_file_names
]
原创粉丝点击