第四篇:对cifar-10数据集的读取

来源:互联网 发布:全球一条线直销软件 编辑:程序博客网 时间:2024/05/16 15:18

介绍:

cifar-10数据集包括50000张训练用的32x32x3的图片和10000张最后测试用的测试集,包括data_batch_1、data_batch_2、data_batch_3、data_batch_4、data_batch_5等,这里展示一个简单的函数读取这个五个batch[10000,3,32,32],然后将五个张量整合到一个矩阵里边[50000,3,32,32]。

测试程序:

# --coding:utf-8 --import tensorflow as tfimport numpy as npimport timeimport matplotlib.pyplot as plt# 读取单个的batch文件def unpickle(datafile):    import cPickle    with open('./cifar-10-batches-py/'+datafile,'rb') as fo:        dict = cPickle.load(fo)    return dictstart_time = time.time()data1 = unpickle('data_batch_1')data2 = unpickle('data_batch_2')data3 = unpickle('data_batch_3')data4 = unpickle('data_batch_4')data5 = unpickle('data_batch_5')# 读取5次data-batch 然后将五个数据整合到一个矩阵X1 = data1['data']label1 = data1['labels']X1 = np.array(X1)np.set_printoptions(threshold='nan')new1 = X1.reshape(-1,3,32,32)X2 = data2['data']label2 = data2['labels']X2 = np.array(X2)np.set_printoptions(threshold='nan')new2 = X2.reshape(-1,3,32,32)X3 = data3['data']label3 = data3['labels']X3 = np.array(X3)np.set_printoptions(threshold='nan')new3 = X3.reshape(-1,3,32,32)X4 = data4['data']label4 = data4['labels']X4 = np.array(X4)np.set_printoptions(threshold='nan')new4 = X4.reshape(-1,3,32,32)X5 = data5['data']label2 = data5['labels']X5 = np.array(X5)np.set_printoptions(threshold='nan')new5 = X5.reshape(-1,3,32,32)X = np.vstack((new1,new2,new3,new4,new5))end_time = time.time() - start_timeprint("end_time:{0:f}\n".format(end_time))# 因为使用imshow将一个矩阵显示为RGB图片,需要# 将三个32*32的矩阵合成一个32*32*3的三维矩阵# 下面就是先将这三个矩阵(32*32)转化为1024*1的向量# 然后使用hstack的功能将每个矩阵上相同位置的值合成# 一个RGB像素点--->[r,g,b]# 最后得到 1024*3的矩阵red   = X[49999][0].reshape(1024,1)green = X[49999][1].reshape(1024,1)blue  = X[49999][2].reshape(1024,1)pic = np.hstack((red,green,blue))# 打印最开始的32*32的矩阵,# 因为为RGB图像,所以为有三个32*32的矩阵# 重新设置pic的形状pic_rgb = pic.reshape(32,32,3)# imshow显示的图片格式应该是# (n,m) or (n,m,3) or (n,m,4)# 显示最后得到的rgb图片plt.imshow(pic_rgb)plt.legend()plt.show()
方法比较简陋,使用的五个X1到X5,来开始存储data_batch_1~5,然后使用np.vstack将五个Xi合并成一个X,最后变成shape为[50000,3,32,32]的张量,到后边应该可以改进读取部分,使用循环,但是不知道循环里代替X(i)编号如何改变。

改进的测试程序:

# --coding:utf-8 --import tensorflow as tfimport numpy as npimport matplotlib.pyplot as plt# 读取单个的batch文件def unpickle(file):    import cPickle    with open('./cifar-10-batches-py/'+file,'rb') as fo:        dict = cPickle.load(fo)    return dictmydata = unpickle('data_batch_1')X = mydata['data']label = mydata['labels']X = np.array(X)np.set_printoptions(threshold='nan')new = X.reshape(10000,3,32,32)# 因为使用imshow将一个矩阵显示为RGB图片,需要# 将三个32*32的矩阵合成一个32*32*3的三维矩阵# 方法一,单个图像转化,可理解、可计算,可以作为方法二的验证 # 下面就是先将这三个矩阵(32*32)转化为1024*1的向量# 然后使用hstack的功能将每个矩阵上相同位置的值合成# 一个RGB像素点--->[r,g,b]# 最后得到 1024*3的矩阵red   = new[1990][0].reshape(1024,1)green = new[1990][1].reshape(1024,1)blue  = new[1990][2].reshape(1024,1)pic = np.hstack((red,green,blue))# 重新设置pic的形状pic_rgb = pic.reshape(32,32,3)# 方法二 使用np.transpose,用法如下,可以使用方法一进行验证两种方法是否一样# 这里的transpose函数里边的参数代表对于# [batch][3][32][32] 对于维度编号是0,1,2,3# 如果转化为[batch][32][32][3],其中0位置不变,把1插到原来0、2之间pic_test = new.transpose((0,2,3,1))# 重新设置pic的形状pic_rgb = pic.reshape(32,32,3)# imshow显示的图片格式应该是# (n,m) or (n,m,3) or (n,m,4)# 显示最后得到的rgb图片plt.imshow(pic_test[1990])plt.legend()plt.show()
        这里使用np.transpose函数,今天刚发现,对于多维度的转置,好用。最后实现的将读取的图片从[batch][3][32][32]转化为[batch][32][32][3],但是不好理解,所以也使用方法一作为验证(方法一好理解,可以计算验证每一个数据到了想要的地方)。





原创粉丝点击