caffe2 study notes 03: from images to an mdb dataset



  • caffe2 study notes 03: from images to an mdb dataset
    • Preface
    • Importing modules
    • Preparation
    • write function: read image files and labels, convert them to an mdb file
    • read function: read the mdb file back and verify it (optional)
    • Execution
    • Possible errors
      • The CHW vs. HWC problem
      • The channel-mismatch problem

1. Preface

This post uses training a caffe2 model to recognize Chinese characters as its running example.

2. Importing modules

If the output is `Required modules imported.`, the imports succeeded; if a module is reported missing, search for how to install it.

```python
# -*- coding: UTF-8 -*-
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import math
import argparse
import numpy as np
import lmdb
import skimage
import skimage.io as io
import skimage.transform
from matplotlib import pyplot
import matplotlib.image as mpimg
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, model_helper
print("Required modules imported.")
```

3. Preparation

Set the data path, define the label lookup table, and cap the maximum output file size.

```python
path = "/home/hw/H/00_dataOfPlate/15_hanzi/01_new_chn/train"  # data path
sep = os.path.sep  # path separator on the current system (Linux)
# label lookup table: 31 entries
chn = ["beijing", "tianjin", "hebei", "shanxi", "neimenggu", "liaoning",
       "jilin", "heilongjiang", "shanghai", "jiangsu", "zhejiang", "anhui",
       "fujian", "jiangxi", "shandong", "henan", "hubei", "hunan",
       "guangdong", "guangxi", "hainan", "sichuan", "guizhou", "yunnan",
       "chongqin", "xizang", "shengxi", "gansu", "qinghai", "ningxia",
       "xinjiang"]  # No. 31
LMDB_MAP_SIZE = 1099511627776  # max output file size < 1 TB
print("prepared")
```
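Each image's label is simply the index of its folder name in the `chn` list. A quick sketch with a shortened stand-in list (the real one above has 31 entries):

```python
# Shortened stand-in for the chn list above, for illustration only.
chn = ["beijing", "tianjin", "hebei", "shanxi"]

# An image under the "hebei" folder gets the index of that name as its label.
label = chn.index("hebei")
print(label)  # → 2
```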

4. write function: read image files and labels, and convert them to an mdb file

File layout, taking the train folder as an example: train contains 26 letter-named subfolders, and each image's label is taken from the folder it resides in.
- train
  - A
    - 0001.bmp
    - 0002.bmp
    - …
    - 4000.bmp
  - B
    - 0001.bmp
    - …
    - 4100.bmp
  - …

| Level-1 dir | Level-2 dir | Image    |
| ----------- | ----------- | -------- |
| train       | A           | 1022.bmp |
| train       | A           | …        |
| train       | A           | 4032.bmp |
| train       | B           | 1022.bmp |
```python
def write_db_with_caffe2(output_file):
    print(">>> Write database ...")
    LMDB_MAP_SIZE = 1099511627776
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)
    checksum = 0
    checksumm = 0
    j = 0
    with env.begin(write=True) as txn:
        for dirs in os.listdir(path):
            new_path = path + sep + dirs
            label = chn.index(dirs)  # folder name -> label index
            for pics in os.listdir(new_path):
                pic_path = new_path + sep + pics
                img_data = skimage.img_as_float(
                    skimage.io.imread(pic_path)).astype(np.float)
                img_data = img_data[:, :, :1]  # 3 channels -> 1 channel
                img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)  # HWC -> CHW
                tensor_protos = caffe2_pb2.TensorProtos()
                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = 1  # FLOAT
                flatten_img = img_data.reshape(np.prod(img_data.shape))
                img_tensor.float_data.extend(flatten_img.flat)
                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = 2  # INT32
                label_tensor.int32_data.append(label)
                txn.put('{}'.format(j).encode('ascii'),
                        tensor_protos.SerializeToString())
                checksum += np.sum(img_data) * label
                checksumm += np.sum(img_data)
                j += 1
    print("Checksum/write: {}".format(int(checksum)))
    print("Checksumm/write: {}".format(int(checksumm)))
```
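The write function stores each image as its `dims` plus a flat list of floats; the reader only needs a reshape to recover the tensor. A minimal numpy-only sketch of that round trip (no LMDB involved):

```python
import numpy as np

# A fake 1x20x20 CHW image, like one record written above.
img = np.random.rand(1, 20, 20).astype(np.float32)

dims = img.shape                          # what goes into img_tensor.dims
flat = img.reshape(np.prod(img.shape))    # what goes into float_data

restored = np.asarray(flat).reshape(dims)  # what the reader reconstructs
assert np.array_equal(img, restored)
```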

5. read function: read the mdb file back, and verify it (this step is optional)

Call it as `read_db_with_caffe2(db_file, expected_checksum)`:
db_file: path to the database file
expected_checksum: the expected checksum, which should match the value printed by write_db_with_caffe2

```python
def read_db_with_caffe2(db_file, expected_checksum):
    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    # Total number of records; this must be exact, otherwise the
    # verification fails ("Read/write checksums dont match").
    batch_size = 744000
    data, label = model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")
    checksum = 0
    checksumm = 0
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    for _ in range(0, 1):
        workspace.RunNet(model.net.Proto().name)
        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            checksum += np.sum(img_datas[j, :]) * labels[j]
            checksumm += np.sum(img_datas[j, :])
    print("Checksum/read: {}".format(int(checksum)))
    print("minus of read and write: {}".format(
        np.abs(expected_checksum - checksum)))
    assert np.abs(expected_checksum - checksum) < 0.1, \
        "Read/write checksums dont match"
```

6. Execution

This takes a while, so be patient: writing 744000 grayscale images of size 20*20 took about twenty minutes, and reading the db data back for verification froze my machine, o(╯□╰)o.

```python
write_db_with_caffe2("./chn_db")
# 640020532 is the expected checksum; it should equal the checksum
# printed by write_db_with_caffe2.
read_db_with_caffe2("./chn_db", 640020532)
```
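The verification relies on the sum of `sum(img) * label` over all records being the same whether it is accumulated at write time or at read time. A toy version of that arithmetic, with made-up images and labels:

```python
import numpy as np

# Two fake 1x2x2 images with labels 0 and 3.
imgs = [np.ones((1, 2, 2)), 2 * np.ones((1, 2, 2))]
labels = [0, 3]

# Same accumulation as in write_db_with_caffe2 / read_db_with_caffe2.
checksum = sum(np.sum(im) * lb for im, lb in zip(imgs, labels))
print(int(checksum))  # 1*4*0 + 2*4*3 = 24
```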

7. Possible errors

1. The CHW vs. HWC problem:

input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1
Cause: images are read in with shape HWC (height/width/channels), but caffe2 expects image data in CHW order, so a conversion is needed; without it the error is:

RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1 Error from operator: input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"

Fix: when converting the images to the mdb file, add `img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)` (already included in the code above).
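The axis swap is easy to check on a dummy array; `np.transpose(img, (2, 0, 1))` is an equivalent one-step form:

```python
import numpy as np

hwc = np.zeros((20, 20, 3))              # skimage layout: H, W, C
chw = hwc.swapaxes(1, 2).swapaxes(0, 1)  # caffe2 layout: C, H, W
print(chw.shape)  # → (3, 20, 20)

# Equivalent single transpose:
assert np.transpose(hwc, (2, 0, 1)).shape == chw.shape
```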

2. The channel-mismatch problem:

input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1
Cause: images are read with three channels by default, whether or not they are grayscale; after the HWC -> CHW conversion from point 1 above, the channel count is 3, which does not match the single channel of the MNIST LeNet example, so the error is:

RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1 Error from operator:
input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"

Fix: when converting the images to the mdb file, add `img_data = img_data[:,:,:1]` (already included in the code above).
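Note that this slice simply keeps the first channel while preserving the channel axis; it is not a luminance conversion. For a proper grayscale you could instead use `skimage.color.rgb2gray` (which drops the channel axis, so you would need to add it back). A numpy-only sketch of what the slice does:

```python
import numpy as np

img = np.random.rand(20, 20, 3)  # an RGB image as skimage reads it

gray = img[:, :, :1]             # keep channel 0 only; array stays 3-D
print(gray.shape)  # → (20, 20, 1)

# For grayscale images stored with the same value in R, G, and B this
# is lossless; for true color images it discards the G and B channels.
assert np.array_equal(gray[:, :, 0], img[:, :, 0])
```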
