机器学习-学习笔记 Cifar10(普适物体识别)
来源:互联网 发布:淘宝能赚到钱吗 编辑:程序博客网 时间:2024/05/29 19:23
Cifar10
Cifar10是Caffe自带Demo里的一个数据集,我们先按照之前的MNIST来进行下载数据集,并且进行训练。
下载数据集
有了上次MNIST的例子,这次就很简单啦~
首先,我们需要运行data/cifar10/get_cifar10.sh
./data/cifar10/get_cifar10.sh
运行后会下载一些数据,放在examples下。
我们先看看examples/cifar10下多了什么。
ls ./examples/cifar10
运行create_cifar10.sh
./examples/cifar10/create_cifar10.sh
接着将cifar10_quick_solver.prototxt文件里修改GPU改为CPU。
运行train_quick.sh
./examples/cifar10/train_quick.sh
接着就是漫长的等待了,是不是等不及了呢- -我之前跑了full,跑了二天还没跑完,只跑了5w次迭代,最后我会把这个5w次迭代的快照发出来,今天再跑跑看quick,应该四五个小时就能跑完了。
现在先拿5w的快照跑一下test,感受一下- -
./build/tools/caffe test -model ./examples/cifar10/cifar10_full_train_test.prototxt -weights ./examples/cifar10/cifar10_full_iter_50000.caffemodel.h5 -iterations 20
./build/tools/caffe test -model ./examples/cifar10/cifar10_full_train_test.prototxt -weights ./examples/cifar10/cifar10_full_iter_10000.caffemodel.h5 -iterations 20
可以看到,1w到5w次,成功率只是上升了百分之三,full来说,总共需要跑6w次迭代,接着再跑lr1,lr2,总共19w次迭代,不过77%目前来说还可以,我们先试一试,找个图片,跑一跑。
我们先看一下cifar10的数据集是什么样的。
CIFAR-10
根据上面的描述,说的是一个二进制文件里有1w个彩色图片(每个图片3072个基本单位(numpy ‘s uint8s)),并且标签占一位,在每个图片的开头是一个单位的标签,所以一个图片就占3073个单位,读取代码如下
别忘了先创建cifar10_test文件夹。
mkdir ./cifar10_test
期间需要使用一些类库,可以使用以下代码进行下载。
pip install --upgrade setuptoolspip install numpy Matplotlibpip install opencv-python
下面才是正题,主要就是先提取数据(Label,R, G, B),接着使用merge将三个维度合并,接着使用Iamge.fromarray转换为图片,并进行保存。
import numpy as npimport structimport matplotlib.pyplot as pltimport Imageimport cv2 as cvdef unzip(filename): binfile = open(filename, 'rb') buf = binfile.read() index = 0 numImages = 10000 i = 0 for image in range(0, numImages): label = struct.unpack_from('>1B', buf, index) index += struct.calcsize('>1B') imR = struct.unpack_from('>1024B', buf, index) index += struct.calcsize('>1024B') imR = np.array(imR, dtype='uint8') imR = imR.reshape(32, 32) imG = struct.unpack_from('>1024B', buf, index) imG = np.array(imG, dtype='uint8') imG = imG.reshape(32, 32) index += struct.calcsize('>1024B') imB = struct.unpack_from('>1024B', buf, index) imB = np.array(imB, dtype='uint8') imB = imB.reshape(32, 32) index += struct.calcsize('>1024B') im = cv.merge((imR, imG, imB)) im = Image.fromarray(im) im.save('cifar10_test/train_%s_%s.png' % (label , image), 'png')unzip("./caffe/data/cifar10/data_batch_1.bin")
那具体label对应的是什么,去哪里找呢?
cat ./caffe/data/cifar10/batches.meta.txt
以下是输出结果(一共10个)
airplaneautomobilebirdcatdeerdogfroghorseshiptruck
验证
我们现在得到了图片,也得到了训练好的网络,我们先仿照MNIST进行实验看看。
#coding=utf-8import syssys.path.insert(0, './caffe/python');import numpy as npimport structimport matplotlib.pyplot as pltimport Imageimport cv2 as cvdef unzip(filename): binfile = open(filename, 'rb'); buf = binfile.read(); import caffe index = 0; numImages = 10000; i = 0; n = np.zeros(10); for image in range(0, numImages): label = struct.unpack_from('>1B', buf, index); index += struct.calcsize('>1B'); imR = struct.unpack_from('>1024B', buf, index); index += struct.calcsize('>1024B'); imR = np.array(imR, dtype='uint8'); imR = imR.reshape(32, 32); imG = struct.unpack_from('>1024B', buf, index); imG = np.array(imG, dtype='uint8'); imG = imG.reshape(32, 32); index += struct.calcsize('>1024B'); imB = struct.unpack_from('>1024B', buf, index); imB = np.array(imB, dtype='uint8'); imB = imB.reshape(32, 32); index += struct.calcsize('>1024B'); im = cv.merge((imR, imG, imB)); im = Image.fromarray(im); im.save('cifar10_test/train_%s_%s.png' % (label , n[label]), 'png'); n[label] += 1;def getType(file): img = caffe.io.load_image(file, color=True); net = caffe.Classifier('./caffe/examples/cifar10/cifar10_full.prototxt', './caffe/examples/cifar10/cifar10_full_iter_50000.caffemodel.h5', channel_swap=(2, 1, 0), raw_scale=255,image_dims=(32, 32)); pre = net.predict([img]); caffe.set_mode_cpu(); return pre[0].argmax();#unzip("./caffe/data/cifar10/data_batch_1.bin")import fileinputlabelName = './caffe/data/cifar10/batches.meta.txt';file = open(labelName, 'r');lines = file.read(100000);import relabel = re.split('\n', lines);import caffeimport randomfor i in range(1, 21): fileName = 'cifar10_test/train_(%d,)_%d.0.png' % (random.choice(range(0, 10)), random.choice(range(0, 100))); img = caffe.io.load_image(fileName); plt.subplot(4, 5, i); plt.imshow(img); plt.title(label[getType(fileName)]);plt.show();
结果还是可以的,不过这些都是训练集的数据,我们随便在网上找一张图片试试看。
#coding=utf-8import syssys.path.insert(0, './caffe/python');import numpy as npimport structimport matplotlib.pyplot as pltimport Imageimport cv2 as cvdef unzip(filename): binfile = open(filename, 'rb'); buf = binfile.read(); import caffe index = 0; numImages = 10000; i = 0; n = np.zeros(10); for image in range(0, numImages): label = struct.unpack_from('>1B', buf, index); index += struct.calcsize('>1B'); imR = struct.unpack_from('>1024B', buf, index); index += struct.calcsize('>1024B'); imR = np.array(imR, dtype='uint8'); imR = imR.reshape(32, 32); imG = struct.unpack_from('>1024B', buf, index); imG = np.array(imG, dtype='uint8'); imG = imG.reshape(32, 32); index += struct.calcsize('>1024B'); imB = struct.unpack_from('>1024B', buf, index); imB = np.array(imB, dtype='uint8'); imB = imB.reshape(32, 32); index += struct.calcsize('>1024B'); im = cv.merge((imR, imG, imB)); im = Image.fromarray(im); im.save('cifar10_test/train_%s_%s.png' % (label , n[label]), 'png'); n[label] += 1;def getType(file): img = caffe.io.load_image(file, color=True); net = caffe.Classifier('./caffe/examples/cifar10/cifar10_full.prototxt', './caffe/examples/cifar10/cifar10_full_iter_50000.caffemodel.h5', channel_swap=(2, 1, 0), raw_scale=255,image_dims=(32, 32)); pre = net.predict([img]); caffe.set_mode_cpu(); return pre[0].argmax();def getLabel(labelName): file = open(labelName, 'r'); lines = file.read(100000); import re label = re.split('\n', lines); return label;# 定义缩放resize函数def resize(image, width=None, height=None, inter=cv.INTER_AREA): # 初始化缩放比例,并获取图像尺寸 dim = None (h, w) = image.shape[:2] # 如果宽度和高度均为0,则返回原图 if width is None and height is None: return image # 宽度是0 if width is None: # 则根据高度计算缩放比例 r = height / float(h) dim = (int(w * r), height) # 如果高度为0 else: # 根据宽度计算缩放比例 r = width / float(w) dim = (width, int(h * r)) # 缩放图像 resized = cv.resize(image, dim, interpolation=inter) # 返回缩放后的图像 return resizedimport caffefileName = 'cifar10_test/timg.jpeg';img = caffe.io.load_image(fileName);plt.subplot(1, 2, 1);plt.imshow(img);plt.title('Test Image');img = resize(img, 32, 32);plt.subplot(1, 2, 2);plt.imshow(img);plt.title(getLabel('./caffe/data/cifar10/batches.meta.txt')[getType(fileName)]);plt.show();
- 机器学习-学习笔记 Cifar10(普适物体识别)
- 使用Caffe基于cifar10进行物体识别
- 学习笔记-CIFAR10模型理解简述
- 【caffe学习笔记】windows下跑cifar10
- caffe学习笔记--自定义cifar10网络参数
- 谷歌推出新机器学习API,可识别、搜索视频中物体
- 李飞飞公布新机器学习API:从视频中识别物体
- 机器学习-学习笔记 Caffe安装-MNIST(手写体数字识别)
- 深度学习框架Caffe学习笔记(7)-cifar10例程
- piotr dollar物体识别库 学习一
- caffe学习1--cifar10
- Pytorch学习-CIFAR10分类
- 【caffe学习笔记——cifar10】win10+caffe环境下cifar10运行
- Unity3d 学习笔记(1)-物体
- [机器学习笔记]Note16--应用示例:图像文字识别
- KNN--用于手写数字识别(机器学习入门笔记)
- KNN--用于手写数字识别(机器学习入门笔记)
- 学习笔记:Caffe上配置和运行Cifar10的示例
- 1001. A+B Format (20)
- Android_SQL详解
- Linux日常
- java运行原理
- RBAC用户角色权限设计方案
- 机器学习-学习笔记 Cifar10(普适物体识别)
- Zookeepr和Hadoop,Hbase的关系
- mysql 有则更新无则插入
- PHP 购物车 session (ThinkPHP)
- STL(标准模板库)string(一)
- 自定义圆从屏幕左上角匀速移动到右下角
- java学习系列3(集合hashmap)
- CommandLineRunner详解
- iOS 局部变量 全局变量 成员变量