机器学习-学习笔记 Cifar10(普适物体识别)

来源:互联网 发布:淘宝能赚到钱吗 编辑:程序博客网 时间:2024/05/29 19:23

Cifar10

Cifar10是Caffe自带Demo里的一个数据集,我们先按照之前的MNIST来进行下载数据集,并且进行训练。

下载数据集

有了上次MNIST的例子,这次就很简单啦~
首先,我们需要运行data/cifar10/get_cifar10.sh

./data/cifar10/get_cifar10.sh

运行后会下载一些数据,放在examples下。

我们先看看examples/cifar10下多了什么。

ls ./examples/cifar10

运行create_cifar10.sh

./examples/cifar10/create_cifar10.sh

接着将cifar10_quick_solver.prototxt文件里修改GPU改为CPU。

运行train_quick.sh

./examples/cifar10/train_quick.sh

接着就是漫长的等待了,是不是等不及了呢- -我之前跑了full,跑了二天还没跑完,只跑了5w次迭代,最后我会把这个5w次迭代的快照发出来,今天再跑跑看quick,应该四五个小时就能跑完了。

现在先拿5w的快照跑一下test,感受一下- -

./build/tools/caffe test -model ./examples/cifar10/cifar10_full_train_test.prototxt -weights ./examples/cifar10/cifar10_full_iter_50000.caffemodel.h5 -iterations 20

这里写图片描述

./build/tools/caffe test -model ./examples/cifar10/cifar10_full_train_test.prototxt -weights ./examples/cifar10/cifar10_full_iter_10000.caffemodel.h5 -iterations 20

这里写图片描述

可以看到,1w到5w次,成功率只是上升了百分之三,full来说,总共需要跑6w次迭代,接着再跑lr1,lr2,总共19w次迭代,不过77%目前来说还可以,我们先试一试,找个图片,跑一跑。

我们先看一下cifar10的数据集是什么样的。

CIFAR-10

这里写图片描述

根据上面的描述,说的是一个二进制文件里有1w个彩色图片(每个图片3072个基本单位(numpy ‘s uint8s)),并且标签占一位,在每个图片的开头是一个单位的标签,所以一个图片就占3073个单位,读取代码如下

别忘了先创建cifar10_test文件夹。

mkdir ./cifar10_test

期间需要使用一些类库,可以使用以下代码进行下载。

pip install --upgrade setuptoolspip install numpy Matplotlibpip install opencv-python

下面才是正题,主要就是先提取数据(Label,R, G, B),接着使用merge将三个维度合并,接着使用Iamge.fromarray转换为图片,并进行保存。

import numpy as npimport structimport matplotlib.pyplot as pltimport Imageimport cv2 as cvdef unzip(filename):    binfile = open(filename, 'rb')    buf = binfile.read()    index = 0    numImages = 10000    i = 0    for image in range(0, numImages):        label = struct.unpack_from('>1B', buf, index)        index += struct.calcsize('>1B')        imR = struct.unpack_from('>1024B', buf, index)        index += struct.calcsize('>1024B')        imR = np.array(imR, dtype='uint8')        imR = imR.reshape(32, 32)        imG = struct.unpack_from('>1024B', buf, index)        imG = np.array(imG, dtype='uint8')        imG = imG.reshape(32, 32)        index += struct.calcsize('>1024B')        imB = struct.unpack_from('>1024B', buf, index)        imB = np.array(imB, dtype='uint8')        imB = imB.reshape(32, 32)        index += struct.calcsize('>1024B')        im = cv.merge((imR, imG, imB))        im = Image.fromarray(im)        im.save('cifar10_test/train_%s_%s.png' % (label , image), 'png')unzip("./caffe/data/cifar10/data_batch_1.bin")

这里写图片描述

那具体label对应的是什么,去哪里找呢?

cat ./caffe/data/cifar10/batches.meta.txt

以下是输出结果(一共10个)

airplaneautomobilebirdcatdeerdogfroghorseshiptruck

验证

我们现在得到了图片,也得到了训练好的网络,我们先仿照MNIST进行实验看看。

#coding=utf-8import syssys.path.insert(0, './caffe/python');import numpy as npimport structimport matplotlib.pyplot as pltimport Imageimport cv2 as cvdef unzip(filename):    binfile = open(filename, 'rb');    buf = binfile.read();    import caffe    index = 0;    numImages = 10000;    i = 0;    n = np.zeros(10);    for image in range(0, numImages):        label = struct.unpack_from('>1B', buf, index);        index += struct.calcsize('>1B');        imR = struct.unpack_from('>1024B', buf, index);        index += struct.calcsize('>1024B');        imR = np.array(imR, dtype='uint8');        imR = imR.reshape(32, 32);        imG = struct.unpack_from('>1024B', buf, index);        imG = np.array(imG, dtype='uint8');        imG = imG.reshape(32, 32);        index += struct.calcsize('>1024B');        imB = struct.unpack_from('>1024B', buf, index);        imB = np.array(imB, dtype='uint8');        imB = imB.reshape(32, 32);        index += struct.calcsize('>1024B');        im = cv.merge((imR, imG, imB));        im = Image.fromarray(im);        im.save('cifar10_test/train_%s_%s.png' % (label , n[label]), 'png');        n[label] += 1;def getType(file):    img = caffe.io.load_image(file, color=True);    net = caffe.Classifier('./caffe/examples/cifar10/cifar10_full.prototxt',     './caffe/examples/cifar10/cifar10_full_iter_50000.caffemodel.h5', channel_swap=(2, 1, 0),     raw_scale=255,image_dims=(32, 32));    pre = net.predict([img]);    caffe.set_mode_cpu();    return pre[0].argmax();#unzip("./caffe/data/cifar10/data_batch_1.bin")import fileinputlabelName = './caffe/data/cifar10/batches.meta.txt';file = open(labelName, 'r');lines = file.read(100000);import relabel = re.split('\n', lines);import caffeimport randomfor i in range(1, 21):    fileName = 'cifar10_test/train_(%d,)_%d.0.png' % (random.choice(range(0, 10)), random.choice(range(0, 100)));    img = caffe.io.load_image(fileName);    plt.subplot(4, 5, i);    plt.imshow(img);    plt.title(label[getType(fileName)]);plt.show();

这里写图片描述

结果还是可以的,不过这些都是训练集的数据,我们随便在网上找一张图片试试看。

#coding=utf-8import syssys.path.insert(0, './caffe/python');import numpy as npimport structimport matplotlib.pyplot as pltimport Imageimport cv2 as cvdef unzip(filename):    binfile = open(filename, 'rb');    buf = binfile.read();    import caffe    index = 0;    numImages = 10000;    i = 0;    n = np.zeros(10);    for image in range(0, numImages):        label = struct.unpack_from('>1B', buf, index);        index += struct.calcsize('>1B');        imR = struct.unpack_from('>1024B', buf, index);        index += struct.calcsize('>1024B');        imR = np.array(imR, dtype='uint8');        imR = imR.reshape(32, 32);        imG = struct.unpack_from('>1024B', buf, index);        imG = np.array(imG, dtype='uint8');        imG = imG.reshape(32, 32);        index += struct.calcsize('>1024B');        imB = struct.unpack_from('>1024B', buf, index);        imB = np.array(imB, dtype='uint8');        imB = imB.reshape(32, 32);        index += struct.calcsize('>1024B');        im = cv.merge((imR, imG, imB));        im = Image.fromarray(im);        im.save('cifar10_test/train_%s_%s.png' % (label , n[label]), 'png');        n[label] += 1;def getType(file):    img = caffe.io.load_image(file, color=True);    net = caffe.Classifier('./caffe/examples/cifar10/cifar10_full.prototxt',     './caffe/examples/cifar10/cifar10_full_iter_50000.caffemodel.h5', channel_swap=(2, 1, 0),     raw_scale=255,image_dims=(32, 32));    pre = net.predict([img]);    caffe.set_mode_cpu();    return pre[0].argmax();def getLabel(labelName):    file = open(labelName, 'r');    lines = file.read(100000);    import re    label = re.split('\n', lines);    return label;# 定义缩放resize函数def resize(image, width=None, height=None, inter=cv.INTER_AREA):    # 初始化缩放比例,并获取图像尺寸    dim = None    (h, w) = image.shape[:2]    # 如果宽度和高度均为0,则返回原图    if width is None and height is None:        return image    # 宽度是0    if width is None:        # 则根据高度计算缩放比例        r = height / float(h)        dim = (int(w * r), height)    # 如果高度为0    else:        # 根据宽度计算缩放比例        r = width / float(w)        dim = (width, int(h * r))    # 缩放图像    resized = cv.resize(image, dim, interpolation=inter)    # 返回缩放后的图像    return resizedimport caffefileName = 'cifar10_test/timg.jpeg';img = caffe.io.load_image(fileName);plt.subplot(1, 2, 1);plt.imshow(img);plt.title('Test Image');img = resize(img, 32, 32);plt.subplot(1, 2, 2);plt.imshow(img);plt.title(getLabel('./caffe/data/cifar10/batches.meta.txt')[getType(fileName)]);plt.show();

这里写图片描述

原创粉丝点击