cifar10图片可视化(Win7+python3.5)

来源:互联网 发布:java socket心跳检测 编辑:程序博客网 时间:2024/05/12 05:34
  1. 将%caffe_root%\data\cifar10\get_cifar10.sh转换成linux格式,并用cygwin运行(事先装好wget)。脚本会自动下载并解压,得到batches.meta.txt,data_batch_1.bin,data_batch_2.bin,data_batch_3.bin,data_batch_4.bin,data_batch_5.bin,test_batch.bin。batches.meta.txt为标签对应的类别名;data_batch_x.bin为第x批10000张的训练图片,共5批;test_batch.bin为10000张测试图片。
  2. 创建%caffe_root%\data\cifar10\cifar10_images\data_batch_x(x为1-5共5个文件夹)及%caffe_root%\data\cifar10\cifar10_images\test_batch文件夹,以保存转换后的图片。
  3. 编写如下python脚本,并运行。在对应文件夹下得到图片文件。
# Convert the CIFAR-10 binary batches (test_batch.bin, data_batch_1..5.bin)
# into individual 32x32 PNG files, one file per image.
import struct

import numpy as np

cifar10_data_root = r'd:/caffe/data/cifar10/'

# One record of a batch file: a 1-byte class label followed by 3072 pixel
# bytes, read below as three 32x32 planes used as R, G and B.
_RECORD = struct.Struct('>B3072B')
_NUM_CLASSES = 10
_PROGRESS_EVERY = 500


def load_CIFAR_Labels(filename):
    """Return the class names stored in *filename*, one entry per line.

    Only the trailing newline is removed, so a blank line in the file is
    kept as an empty string and list indices keep matching line numbers.
    """
    with open(filename, 'r') as f:
        return [line.rstrip('\n') for line in f]


def load_CIFAR_batch(filename, label, Pic):
    """Append every record of a CIFAR-10 binary batch to *label* and *Pic*.

    Args:
        filename: path of the ``.bin`` batch file.
        label: list that receives one integer class index per image.
        Pic: list that receives one tuple of 3072 pixel values per image.

    The number of records is derived from the file size instead of being
    fixed at 10000; a trailing partial record, if any, is ignored.
    """
    # The context manager closes the file even if reading fails.
    with open(filename, 'rb') as bin_file:
        buf = bin_file.read()

    for index in range(len(buf) // _RECORD.size):
        fields = _RECORD.unpack_from(buf, index * _RECORD.size)
        label.append(fields[0])
        Pic.append(fields[1:])


def _export_batch(data_path, out_prefix, class_names, progress_label):
    """Save every image of the batch at *data_path* as a PNG file.

    Files are named ``<out_prefix><index>_<class name>_<n>.png`` where
    ``n`` counts the images of that class already written from this batch.
    The target directory must already exist.
    """
    # Imported here so the loaders above stay usable without Pillow.
    from PIL import Image

    print("### " + data_path + " is loading... ###")
    label = []
    Pic = []
    load_CIFAR_batch(data_path, label, Pic)

    count = [0] * _NUM_CLASSES
    for i, (cls, pixels) in enumerate(zip(label, Pic)):
        planes = np.array(pixels, dtype=np.uint8).reshape(3, 32, 32)
        # (channel, row, col) -> (row, col, channel): the layout Pillow
        # expects, replacing the per-pixel putpixel loops.
        image = Image.fromarray(np.ascontiguousarray(planes.transpose(1, 2, 0)))
        pic_name = "%s%d_%s_%d.png" % (out_prefix, i, class_names[cls], count[cls])
        count[cls] += 1
        image.save(pic_name, "png")
        if i % _PROGRESS_EVERY == 0:
            print(progress_label + str(i))


def main():
    """Export the test batch and the five training batches to PNG files."""
    lines = load_CIFAR_Labels(cifar10_data_root + "batches.meta.txt")
    print(lines)

    ##### TEST DATA #####
    _export_batch(
        cifar10_data_root + r"test_batch.bin",
        cifar10_data_root + "cifar10_images/test_batch/test_",
        lines,
        "testSet processed:",
    )

    ##### TRAIN DATA #####
    for j in range(1, 6):
        # Use the loop index: the original always formatted the path with 1
        # and therefore exported batch 1 five times.
        _export_batch(
            cifar10_data_root + r"data_batch_%d.bin" % j,
            cifar10_data_root + "cifar10_images/data_batch_%d/train[%d]_" % (j, j),
            lines,
            "trainSet processed:patch(%d) " % j,
        )


if __name__ == "__main__":
    main()

最后生成文件总占用空间234M。

原创粉丝点击