caffe2 Study Notes 03: From Images to an mdb Dataset
Source: Internet | Editor: 程序博客网 | Time: 2024/06/07 04:08
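- caffe2 Study Notes 03: From Images to an mdb Dataset
  - Preface
  - Importing the libraries
  - Preparation
  - The write function: read image files and labels and convert them to an mdb file
  - The read function: read the mdb file and verify it (this step is optional)
  - Running
  - Errors you may encounter
    - The CHW vs. HWC problem
    - The channel mismatch problem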
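1. Preface
This article uses training a Chinese-character recognition model with caffe2 as its example.
2. Importing the libraries
If the output is Required modules imported. then the import succeeded. If it reports that a module is missing, search the web for how to install it.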
# -*- coding: UTF-8 -*-
# __future__ imports must come before any other statement
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

%matplotlib inline
import os
import sys
import math
import argparse

import numpy as np
import skimage
import skimage.io as io
import skimage.transform
from matplotlib import pyplot
import matplotlib.image as mpimg
import lmdb
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, model_helper

print("Required modules imported.")
3. Preparation
Set the data path, the label lookup table, and the maximum output file size.
path = "/home/hw/H/00_dataOfPlate/15_hanzi/01_new_chn/train"  # data path
sep = os.path.sep  # path separator of the current system (Linux)
# Label lookup table (31 entries): the index of a folder name is its label
chn = ["beijing", "tianjin", "hebei", "shanxi", "neimenggu", "liaoning",
       "jilin", "heilongjiang", "shanghai", "jiangsu", "zhejiang", "anhui",
       "fujian", "jiangxi", "shandong", "henan", "hubei", "hunan",
       "guangdong", "guangxi", "hainan", "sichuan", "guizhou", "yunnan",
       "chongqin", "xizang", "shengxi", "gansu", "qinghai", "ningxia",
       "xinjiang"]
LMDB_MAP_SIZE = 1099511627776  # maximum output file size, < 1TB
print("prepared")
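As a small illustration of how the label table is used: the write function below looks a folder name up with chn.index(...), so the label is simply the position of the name in the list. The sketch uses a shortened three-entry table (`chn_demo`, a stand-in introduced here, not the real table) and also builds a dict (`label_of`, likewise a name made up for this example) so that lookups do not scan the list. It also checks that the map size above equals 2**40 bytes.

```python
# Shortened stand-in for the full 31-entry chn table (demo values only)
chn_demo = ["beijing", "tianjin", "hebei"]

# list.index returns the position of the folder name, which serves as the label
label_via_index = chn_demo.index("hebei")

# Optional: a dict gives the same labels without scanning the list each time
label_of = {name: i for i, name in enumerate(chn_demo)}

# The map size used in this article is 2**40 bytes, i.e. 1 TiB
LMDB_MAP_SIZE = 1099511627776
is_one_tib = LMDB_MAP_SIZE == 1 << 40

print(label_via_index, label_of["hebei"], is_one_tib)
```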
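4. The write function: read image files and labels, and convert them to an mdb file
Directory layout, taking the train folder as an example: here train contains the 26 letters, and the label of each image is determined by the folder it sits in. (In the code below the sub-folders are the names listed in chn.)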
- train
  - A
    - 0001.bmp
    - 0002.bmp
    - …
    - 4000.bmp
  - B
    - 0001.bmp
    - …
    - 4100.bmp
  - …
    - …
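The layout above can be turned into (file, label) pairs with the standard library alone. The following is a minimal sketch, not the article's write function: `list_labeled_files` and the temporary demo tree are names and data invented for this example, and the label is taken as the index of the folder name in a given class list, as described above.

```python
import os
import tempfile

def list_labeled_files(root, class_names):
    """Return (file_path, label) pairs; label = index of the folder name."""
    pairs = []
    for folder in sorted(os.listdir(root)):
        folder_path = os.path.join(root, folder)
        if not os.path.isdir(folder_path):
            continue  # skip stray files next to the class folders
        label = class_names.index(folder)
        for name in sorted(os.listdir(folder_path)):
            pairs.append((os.path.join(folder_path, name), label))
    return pairs

# Build a tiny throw-away tree that mimics train/A/0001.bmp, train/B/0001.bmp
demo_root = tempfile.mkdtemp()
for folder, files in [("A", ["0001.bmp", "0002.bmp"]), ("B", ["0001.bmp"])]:
    os.makedirs(os.path.join(demo_root, folder))
    for name in files:
        open(os.path.join(demo_root, folder, name), "wb").close()

demo_pairs = list_labeled_files(demo_root, ["A", "B"])
demo_labels = [label for _, label in demo_pairs]
print(demo_labels)
```

Sorting the directory listings makes the order reproducible between runs; os.listdir itself returns entries in arbitrary order.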
def write_db_with_caffe2(output_file):
    print(">>> Write database ...")
    LMDB_MAP_SIZE = 1099511627776  # upper bound for the LMDB file size
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)
    checksum = 0   # sum over all images of (pixel sum * label)
    checksumm = 0  # plain pixel sum, independent of the labels
    j = 0          # running index, used as the LMDB key
    with env.begin(write=True) as txn:
        for dirs in os.listdir(path):
            new_path = path + sep + dirs
            label = chn.index(dirs)  # the folder name determines the label
            for pics in os.listdir(new_path):
                pic_path = new_path + sep + pics
                # np.float64 replaces the removed alias np.float (same dtype)
                img_data = skimage.img_as_float(
                    skimage.io.imread(pic_path)).astype(np.float64)
                print("before: {}".format(img_data.shape))
                img_data = img_data[:, :, :1]  # 3 channels -> 1 channel
                img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)  # HWC -> CHW
                print("after: {}".format(img_data.shape))

                tensor_protos = caffe2_pb2.TensorProtos()
                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = 1  # FLOAT
                flatten_img = img_data.reshape(np.prod(img_data.shape))
                print("flattened: {}".format(flatten_img.shape))
                img_tensor.float_data.extend(flatten_img.flat)

                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = 2  # INT32
                label_tensor.int32_data.append(label)
                txn.put('{}'.format(j).encode('ascii'),
                        tensor_protos.SerializeToString())

                checksum += np.sum(img_data) * label
                checksumm += np.sum(img_data)
                j += 1
    print("Checksum/write: {}".format(int(checksum)))
    print("Checksumm/write: {}".format(int(checksumm)))
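Each record written above stores two things for the image: its dims and its flattened values. The NumPy-only sketch below (no caffe2 or lmdb involved; the array `img` is made-up demo data) suggests why these two pieces are enough to rebuild the CHW array on the reading side. It assumes the values pass through 32-bit floats, since float_data is a 32-bit float field.

```python
import numpy as np

# Made-up 1x20x20 CHW image standing in for one converted picture
img = np.arange(400, dtype=np.float64).reshape(1, 20, 20) / 400.0

dims = list(img.shape)                   # what goes into img_tensor.dims
flat = img.reshape(np.prod(img.shape))   # what goes into float_data

# float_data holds 32-bit floats, so the values are stored with that precision
stored = flat.astype(np.float32)

# Rebuilding: reshape the flat values back to the recorded dims
restored = stored.reshape(dims)
max_error = float(np.max(np.abs(restored - img)))
print(restored.shape, max_error)
```

The small rounding error from the float32 round trip is one reason the read-side checksum is compared with a tolerance rather than for exact equality.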
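5. The read function: read the mdb file and verify it (this step is optional)
Call it as read_db_with_caffe2(db_file, expected_checksum):
db_file: path of the database to read
expected_checksum: the expected checksum, which should match the value printed by write_db_with_caffe2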
def read_db_with_caffe2(db_file, expected_checksum):
    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    # Total number of records. It must be exact, otherwise verification
    # fails ("Read/write checksums dont match").
    batch_size = 744000
    data, label = model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")
    checksum = 0
    checksumm = 0  # must be initialized before it is accumulated below
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    for _ in range(0, 1):
        workspace.RunNet(model.net.Proto().name)
        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            checksum += np.sum(img_datas[j, :]) * labels[j]
            checksumm += np.sum(img_datas[j, :])
    print("Checksum/read: {}".format(int(checksum)))
    print("minus of read and write: {}".format(
        np.abs(expected_checksum - checksum)))
    # Compare the absolute difference with the tolerance. The tolerance may
    # need loosening: write prints a truncated integer and the database
    # stores the pixel values as 32-bit floats.
    assert np.abs(expected_checksum - checksum) < 0.1, \
        "Read/write checksums dont match"
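The verification idea itself does not depend on caffe2: both sides compute the sum over all records of (pixel sum * label) and compare the two numbers within a tolerance. A minimal pure-Python sketch with made-up records follows; `records` and `weighted_checksum` are illustrative names, not part of the article's code.

```python
def weighted_checksum(records):
    """Sum over records of (sum of pixel values) * label."""
    total = 0.0
    for pixels, label in records:
        total += sum(pixels) * label
    return total

# Made-up (pixels, label) records; the values are exact in binary floating point
records = [([0.5, 0.25], 1), ([1.0, 0.0], 2), ([0.75, 0.75], 0)]

written = weighted_checksum(records)          # computed while writing
read_back = weighted_checksum(list(records))  # recomputed after reading

# Take the absolute value of the difference first, then compare with the tolerance
checksums_match = abs(written - read_back) < 0.1
print(written, checksums_match)
```

Records with label 0 contribute nothing to this sum, which is presumably why the code also tracks the unweighted checksumm.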
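6. Running
Execution takes a long time, so please be patient. Processing 744000 grayscale images of size 20*20 took about twenty minutes, and when the db was read back for the test, the computer froze, o(╯□╰)o;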
write_db_with_caffe2("./chn_db")
# 640020532 is the checksum; it should equal the checksum printed by write
read_db_with_caffe2("./chn_db", 640020532)
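A rough back-of-the-envelope estimate may help explain the freeze: reading with batch_size = 744000 asks for every image in a single batch. The sketch below only does the arithmetic, assuming 20*20 single-channel images held as 32-bit floats; the smaller `batch` of 1000 is an assumed value, chosen to show how many iterations a batched read would need.

```python
num_images = 744000
values_per_image = 1 * 20 * 20       # C * H * W for a 20x20 single-channel image
bytes_per_value = 4                  # 32-bit float

# Size of the "data" blob if all images are fetched as one batch
one_batch_bytes = num_images * values_per_image * bytes_per_value
one_batch_gib = one_batch_bytes / float(1 << 30)

# With a smaller batch (assumed value), the same data takes several iterations
batch = 1000
iterations = num_images // batch
per_batch_bytes = batch * values_per_image * bytes_per_value

print(one_batch_bytes, round(one_batch_gib, 2), iterations, per_batch_bytes)
```

This counts only the blob itself; copies made when fetching it and the Python loop over 744000 rows add to the load. With a smaller batch, the checksum would have to be accumulated across all iterations before comparing.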
7. Errors you may encounter
1. The CHW vs. HWC problem:
input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1
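Cause: images are read with shape HWC (height/width/channels) by default, whereas caffe2's default image layout is CHW, so a conversion is needed. Without it, the first axis of a 20*20 image (its height, 20) is taken as the channel count, and the following error is raised: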
RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1 Error from operator: input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"
Fix: when converting the images to the mdb file, add img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)
(this is already included in the program above)
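To see what the two swaps do, here is a small NumPy sketch on a made-up array. The height and width are deliberately different (20 and 30) so the axes can be told apart; `transpose(2, 0, 1)` is shown as an equivalent single call.

```python
import numpy as np

# Made-up HWC image: height 20, width 30, 3 channels (sizes chosen to differ)
hwc = np.arange(20 * 30 * 3).reshape(20, 30, 3)

# The two swaps used in this article: (H, W, C) -> (H, C, W) -> (C, H, W)
chw_swapped = hwc.swapaxes(1, 2).swapaxes(0, 1)

# Equivalent single call that names the target axis order directly
chw_transposed = hwc.transpose(2, 0, 1)

same = bool(np.array_equal(chw_swapped, chw_transposed))
# Pixel (row 5, col 7), channel 2 keeps its value at index [2, 5, 7]
pixel_kept = bool(hwc[5, 7, 2] == chw_swapped[2, 5, 7])
print(chw_swapped.shape, same, pixel_kept)
```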
2. The channel mismatch problem:
input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1
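Cause: by default the images are read as three channels, whether or not they are grayscale. After the HWC -> CHW conversion from item 1 the channel count is 3, which does not match the single-channel input of the LeNet in the MNIST example, hence the following error: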
RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1 Error from operator:
input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"
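Fix: when converting the images to the mdb file, add the following before the HWC -> CHW conversion (it slices the last, channel axis): img_data = img_data[:,:,:1]
(this is already included in the program above)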
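The slice `[:, :, :1]` keeps the channel axis, which matters for the later HWC -> CHW step. A short NumPy sketch on made-up data (`rgb_like` imitates a gray picture stored with three identical channels) illustrates the difference from plain indexing:

```python
import numpy as np

# Made-up image read as three identical channels (a gray picture stored as RGB)
gray = np.arange(20 * 20, dtype=np.float64).reshape(20, 20)
rgb_like = np.stack([gray, gray, gray], axis=2)   # shape (20, 20, 3)

# Slicing with :1 keeps the channel axis, unlike indexing with 0
one_channel = rgb_like[:, :, :1]                  # shape (20, 20, 1)
dropped_axis = rgb_like[:, :, 0]                  # shape (20, 20)

# After HWC -> CHW the result has the single channel the network expects
chw = one_channel.transpose(2, 0, 1)              # shape (1, 20, 20)
print(one_channel.shape, dropped_axis.shape, chw.shape)
```

Keeping only the first channel assumes all three channels carry the same values; for a true color image it would discard information.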