Caffe + Python: writing image and ground-truth LMDB databases

来源:互联网 发布:群晖网络存储 编辑:程序博客网 时间:2024/05/21 06:25
import numpy as np
import os
import matplotlib.pyplot as plt
import lmdb
from PIL import Image
import random
import sys
# Make the Caffe Python bindings from the local project checkout importable,
# taking precedence over any other installed copy (inserted at sys.path[0]).
caffe_root = '/home/tsq/Documents/project/mcnn/'
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe
# Image list: one file name per line, resolved relative to the 'test_b/' directory.
with open('test_b.txt') as train_file:
    inputs_data_train = train_file.readlines()


def _image_to_bgr_chw(path):
    """Load an image as a uint8 array of shape (3, H, W) with channels in BGR order.

    Every image is converted to 3-channel RGB first, so grayscale images are
    replicated across channels and alpha/palette/CMYK images still yield
    exactly three channels. The channel axis is then reversed (RGB -> BGR,
    presumably to match the order the Caffe model was trained with) and moved
    to the front (HWC -> CHW).
    """
    # Context manager closes the underlying file handle once pixels are read.
    with Image.open(path) as img:
        rgb = np.asarray(img.convert('RGB'), dtype=np.uint8)
    return rgb[:, :, ::-1].transpose((2, 0, 1))


print("Creating Training Data LMDB File ..... ")
in_db = lmdb.open('Train_Data_lmdb', map_size=int(1e12))
try:
    with in_db.begin(write=True) as in_txn:
        # Skip blank lines (e.g. a trailing newline) rather than trying to
        # open the bare 'test_b/' directory.
        image_names = [line.strip() for line in inputs_data_train if line.strip()]
        for in_idx, name in enumerate(image_names):
            im_dat = caffe.io.array_to_datum(_image_to_bgr_chw('test_b/' + name))
            # Zero-padded keys keep lexicographic order equal to insertion order;
            # lmdb keys must be bytes (a plain str is rejected on Python 3).
            key = '{:0>10d}'.format(in_idx).encode('ascii')
            in_txn.put(key, im_dat.SerializeToString())
finally:
    # Release the environment even if an image fails to load.
    in_db.close()

import scipy.io as sio

# Ground-truth list: one .mat file name per line, resolved relative to 'gt2/'.
# Entries are keyed by position, so line i here is expected to pair with
# line i of the image list (same key format).
with open('test_b_gt.txt') as label_file:
    inputs_data_label = label_file.readlines()

print("Creating Training Label LMDB File ..... ")
in_db1 = lmdb.open('Label_Data_lmdb', map_size=int(1e12))
try:
    with in_db1.begin(write=True) as in_txn:
        # Skip blank lines (e.g. a trailing newline) rather than loading 'gt2/'.
        label_names = [line.strip() for line in inputs_data_label if line.strip()]
        for in_idx, name in enumerate(label_names):
            d_map = np.array(sio.loadmat('gt2/' + name)['d_map'], dtype=np.float64)
            # Assumes d_map is 2-D (H, W) -- TODO confirm; the leading axis makes
            # it (1, H, W), the 3-axis shape caffe.io.array_to_datum accepts.
            # The second argument stores label 0 in the datum.
            label_dat = caffe.io.array_to_datum(np.expand_dims(d_map, axis=0), 0)
            # Same zero-padded key scheme as the image DB; lmdb keys must be bytes
            # (a plain str is rejected on Python 3).
            key = '{:0>10d}'.format(in_idx).encode('ascii')
            in_txn.put(key, label_dat.SerializeToString())
finally:
    # Release the environment even if a .mat file fails to load.
    in_db1.close()

print("Finish creating lmdb file ......")
0 0
原创粉丝点击