Mxnet图片分类(1)准备数据集
来源:互联网 发布:道路网络拓扑关系构建 编辑:程序博客网 时间:2024/05/22 03:48
Mxnet做图片分类训练的时候提供了多种输入方式,这里介绍.rec数据的制作。
mxnet版本:0.904
1.收集图片
这里收集了两类图片,猫和狗的图片。
第一类
第二类
所有的图片放到两个文件夹中
2.生成mxnet训练需要的数据格式
如果mxnet自带的im2rec.py有问题,这里用的是另外的文件,文件名叫im2rec2.py,代码在最后面。
#生成.lst文件
python im2rec2.py --list True --exts .jpg --exts .jpeg --recursive=True --train-ratio 0.7 --test-ratio 0.3 train-cat ./train-cat
说明: python im2rec2.py 参数1 参数2 参数3 参数4 参数5 参数6 参数7...
参数1 --list : 设为 True 时调用 make_list 生成 .lst 图片列表
参数2 --exts : 需要查找的图像的扩展名(可重复指定)
参数3 --train-ratio : 用于训练的图片的比例
参数4 --test-ratio : 用于测试的图片的比例
参数5 --recursive : 如果是 True 就会递归寻找子目录的文件,每个子目录对应一个类别标签
参数6 train-cat : 输出 lst 和 rec 文件的前缀
参数7 ./train-cat : 包含图像文件夹的根目录
#生成.rec,im2rec需要编译生成,可以参考文献[1]
../bin/im2rec ./train-cat_train.lst train-cat/ train-cat_train.rec resize=224
说明: ../bin/im2rec 参数1 参数2 参数3 参数4 ...
参数1: 图片列表文件,如 train.lst、test.lst 或 val.lst
参数2: 图片所在的根目录;图片放在 train-cat 目录下时,该参数为 train-cat/,它会与 .lst 文件中的相对路径拼接成图片的完整路径
参数3: 生成的文件名称,如 image.rec
后面的参数根据需求添加(如 resize=224)
train-cat_test.lst的生成方法一样,这里忽略了,同时也忽略了val数据的生成
# -*- coding: utf-8 -*-
"""Create image list files (.lst) and RecordIO databases (.rec) for MXNet.

Two modes, selected with ``--list``:

* ``--list True``  walk ``root`` and write ``<prefix>_train.lst``,
  ``<prefix>_test.lst`` and ``<prefix>_val.lst``.
* ``--list False`` (default) read every ``<prefix>*.lst`` file and pack the
  listed images into a ``.rec`` file next to it.

The code runs on both Python 2.7 and Python 3.
"""
from __future__ import print_function

import argparse
import os
import random
import sys
import time

curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))

# mxnet and cv2 are only needed to pack .rec files; list generation works
# without them, so a missing package is reported only when it is required.
try:
    import mxnet as mx
except ImportError:
    mx = None
try:
    import cv2
except ImportError:
    cv2 = None

DEFAULT_EXTS = ['.jpeg', '.jpg', '.bmp']


def _require_record_deps():
    """Raise ImportError if mxnet or cv2 could not be imported."""
    missing = [name for name, mod in (('mxnet', mx), ('cv2', cv2))
               if mod is None]
    if missing:
        raise ImportError('creating .rec files requires: ' + ', '.join(missing))


def _format_lst_line(item):
    """Return one .lst line: ``index<TAB>label...<TAB>path<NEWLINE>``."""
    fields = ['%d' % item[0]]
    fields.extend('%f' % label for label in item[2:])
    fields.append('%s\n' % item[1])
    return '\t'.join(fields)


def list_image(root, recursive, exts):
    """Yield ``(index, relative_path, label)`` for each image under ``root``.

    Args:
        root: directory to scan.
        recursive: if true, walk sub-directories and give every directory
            that contains images its own label (0, 1, ... in walk order);
            otherwise only scan ``root`` itself and use label 0.
        exts: container of accepted lower-case extensions, e.g. ``['.jpg']``.
    """
    print(root)
    print('exts ', exts)
    # Running counter: every image must get its own index in the list.
    index = 0
    if recursive:
        cat = {}
        for path, subdirs, files in os.walk(root, followlinks=True):
            subdirs.sort()  # in-place sort fixes the traversal (label) order
            print(len(cat), path)
            for fname in files:
                fpath = os.path.join(path, fname)
                suffix = os.path.splitext(fname)[1].lower()
                if os.path.isfile(fpath) and (suffix in exts):
                    if path not in cat:
                        cat[path] = len(cat)
                    yield (index, os.path.relpath(fpath, root), cat[path])
                    index += 1
    else:
        for fname in os.listdir(root):
            fpath = os.path.join(root, fname)
            suffix = os.path.splitext(fname)[1].lower()
            if os.path.isfile(fpath) and (suffix in exts):
                yield (index, os.path.relpath(fpath, root), 0)
                index += 1


def write_list(path_out, image_list):
    """Write ``image_list`` items to ``path_out`` in .lst format."""
    with open(path_out, 'w') as fout:
        for item in image_list:
            fout.write(_format_lst_line(item))


def make_list(args):
    """Scan ``args.root`` and write the test/train/val .lst files.

    Reads ``root``, ``recursive``, ``exts``, ``shuffle``, ``chunks``,
    ``train_ratio``, ``test_ratio`` and ``prefix`` from ``args``.
    """
    image_list = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        random.seed(100)  # fixed seed keeps the split reproducible
        random.shuffle(image_list)
    N = len(image_list)
    # Ceiling division; must stay an int because it is used for slicing.
    chunk_size = (N + args.chunks - 1) // args.chunks
    for i in range(args.chunks):
        chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
        str_chunk = '_%d' % i if args.chunks > 1 else ''
        sep = int(chunk_size * args.train_ratio)
        sep_test = int(chunk_size * args.test_ratio)
        write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
        write_list(args.prefix + str_chunk + '_train.lst',
                   chunk[sep_test:sep_test + sep])
        write_list(args.prefix + str_chunk + '_val.lst',
                   chunk[sep_test + sep:])


def read_list(path_in):
    """Yield ``[index, path, label, ...]`` for each line of a .lst file.

    Blank lines are skipped.
    """
    with open(path_in) as fin:
        for raw in fin:
            if not raw.strip():
                continue
            fields = [field.strip() for field in raw.strip().split('\t')]
            yield ([int(fields[0])] + [fields[-1]] +
                   [float(label) for label in fields[1:-1]])


def image_encode(args, item, q_out):
    """Load, optionally crop/resize, pack one image and put it on ``q_out``.

    On success ``(packed_bytes, item)`` is put on ``q_out``. Unreadable
    images and packing failures are reported on stdout and skipped.

    Raises:
        ImportError: if mxnet or cv2 is not installed.
    """
    _require_record_deps()
    try:
        img = cv2.imread(os.path.join(args.root, item[1]), args.color)
    except Exception:
        print('imread error:', item[1])
        return
    if img is None:
        print('read none error:', item[1])
        return
    if args.center_crop:
        # Crop the longer side symmetrically so the image becomes square.
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        # cv2.resize takes (width, height) as integers; the shorter edge
        # becomes args.resize and the other one is scaled proportionally.
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize,
                       img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0],
                       args.resize)
        img = cv2.resize(img, newsize)
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)
    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality,
                                 img_fmt=args.encoding)
        q_out.put((s, item))
    except Exception as e:
        print('pack_img error:', item[1], e)
        return


def read_worker(args, q_in, q_out):
    """Encode items taken from ``q_in`` until a ``None`` sentinel arrives."""
    while True:
        item = q_in.get()
        if item is None:
            break
        image_encode(args, item, q_out)


def write_worker(q_out, fname, working_dir):
    """Write packed images from ``q_out`` to the .rec file for ``fname``.

    The .lst file is rewritten (through ``fname + '.tmp'``) in the order the
    records were actually written, then renamed over ``fname``. Stops when a
    ``None`` sentinel is received.
    """
    pre_time = time.time()
    count = 0
    fname_rec = os.path.splitext(fname)[0] + '.rec'
    record = mx.recordio.MXRecordIO(os.path.join(working_dir, fname_rec), 'w')
    try:
        with open(fname + '.tmp', 'w') as fout:
            while True:
                deq = q_out.get()
                if deq is None:
                    break
                s, item = deq
                record.write(s)
                fout.write(_format_lst_line(item))
                if count % 1000 == 0:
                    cur_time = time.time()
                    print('time:', cur_time - pre_time, ' count:', count)
                    pre_time = cur_time
                count += 1
    finally:
        record.close()
    # The temporary list is closed (flushed) before it replaces the original.
    os.rename(fname + '.tmp', fname)


def _str2bool(value):
    """Parse a command-line boolean; ``bool('False')`` would be True."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


def _extension(value):
    """Normalise an extension to lower case with a leading dot."""
    ext = value.strip().lower()
    return ext if ext.startswith('.') else '.' + ext


class _AppendReplacingDefault(argparse.Action):
    """Like ``append``, but the first user value replaces the default list."""

    def __call__(self, parser, namespace, values, option_string=None):
        current = getattr(namespace, self.dest, None)
        if current is None or current is self.default:
            current = []
        current.append(values)
        setattr(namespace, self.dest, current)


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: optional list of argument strings; ``sys.argv[1:]`` is used
            when omitted.

    Returns:
        The parsed namespace, with ``prefix`` and ``root`` made absolute.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or '
                    'make a record database by reading from an image list')
    parser.add_argument('prefix',
                        help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument(
        '--list', type=_str2bool, default=False,
        help='If this is set im2rec will create image list(s) by traversing '
             'root folder and output to <prefix>.lst. '
             'Otherwise im2rec will read <prefix>.lst and create a database '
             'at <prefix>.rec')
    cgroup.add_argument(
        '--exts', type=_extension, action=_AppendReplacingDefault,
        default=list(DEFAULT_EXTS),
        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1,
                        help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument(
        '--recursive', type=_str2bool, default=False,
        help='If true recursively walk through subdirs and assign an unique '
             'label to images in each folder. Otherwise only include images '
             'in the root folder and give them label 0.')

    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument(
        '--resize', type=int, default=0,
        help='resize the shorter edge of image to the newsize, original '
             'images will be packed by default.')
    rgroup.add_argument(
        '--center-crop', type=_str2bool, default=False,
        help='specify whether to crop the center image to make it '
             'rectangular.')
    rgroup.add_argument(
        '--quality', type=int, default=80,
        help='JPEG quality for encoding, 1-100; or PNG compression for '
             'encoding, 1-9')
    rgroup.add_argument(
        '--num-thread', type=int, default=1,
        help='number of thread to use for encoding. order of images will be '
             'different from the input list if >1. the input list will be '
             'modified to match the resulting order.')
    rgroup.add_argument(
        '--color', type=int, default=1, choices=[-1, 0, 1],
        help='specify the color mode of the loaded image. '
             '1: Loads a color image. Any transparency of image will be '
             'neglected. It is the default flag. '
             '0: Loads image in grayscale mode. '
             '-1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg',
                        choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument(
        '--shuffle', type=_str2bool, default=True,
        help='If this is set as True, '
             'im2rec will randomize the image order in <prefix>.lst')
    rgroup.add_argument(
        '--pack-label', type=_str2bool, default=False,
        help='Whether to also pack multi dimensional label in the record '
             'file')

    args = parser.parse_args(argv)
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args


def _write_record_parallel(args, image_list, fname, working_dir):
    """Encode with ``args.num_thread`` reader processes and one writer."""
    import multiprocessing
    num_workers = max(1, args.num_thread)  # guards the modulo below
    q_in = [multiprocessing.Queue(1024) for _ in range(num_workers)]
    q_out = multiprocessing.Queue(1024)
    read_process = [
        multiprocessing.Process(target=read_worker,
                                args=(args, q_in[i], q_out))
        for i in range(num_workers)]
    for p in read_process:
        p.start()
    write_process = multiprocessing.Process(
        target=write_worker, args=(q_out, fname, working_dir))
    write_process.start()

    # Distribute items round-robin over the reader queues.
    for i, item in enumerate(image_list):
        q_in[i % len(q_in)].put(item)
    for q in q_in:
        q.put(None)
    for p in read_process:
        p.join()

    # Stop the writer only after every reader has finished producing.
    q_out.put(None)
    write_process.join()


def _write_record_serial(args, image_list, fname, working_dir):
    """Encode and write every item in the current process."""
    try:
        import queue
    except ImportError:  # Python 2
        import Queue as queue
    q_out = queue.Queue()
    fname_rec = os.path.splitext(fname)[0] + '.rec'
    record = mx.recordio.MXRecordIO(os.path.join(working_dir, fname_rec), 'w')
    cnt = 0
    pre_time = time.time()
    try:
        for item in image_list:
            image_encode(args, item, q_out)
            if q_out.empty():
                continue  # image could not be read or packed
            s, _ = q_out.get()  # image_encode puts (packed_bytes, item)
            record.write(s)
            if cnt % 1000 == 0:
                cur_time = time.time()
                print('time:', cur_time - pre_time, ' count:', cnt)
                pre_time = cur_time
            cnt += 1
    finally:
        record.close()


def main():
    """Entry point: build list files, or pack .lst files into .rec files."""
    args = parse_args()
    if args.list:
        make_list(args)
        print('ddd', args.prefix)
        return

    # Checked here so the error is not mistaken for a missing
    # multiprocessing module by the fallback below.
    _require_record_deps()
    if os.path.isdir(args.prefix):
        working_dir = args.prefix
    else:
        working_dir = os.path.dirname(args.prefix)
    files = [os.path.join(working_dir, fname)
             for fname in os.listdir(working_dir)
             if os.path.isfile(os.path.join(working_dir, fname))]
    count = 0
    for fname in files:
        if fname.startswith(args.prefix) and fname.endswith('.lst'):
            print('Creating .rec file from', fname, 'in', working_dir)
            count += 1
            image_list = read_list(fname)
            try:
                _write_record_parallel(args, image_list, fname, working_dir)
            except ImportError:
                print('multiprocessing not available, fall back to single '
                      'threaded encoding')
                _write_record_serial(args, image_list, fname, working_dir)
    if not count:
        print('Did not find and list file with prefix %s' % args.prefix)


if __name__ == '__main__':
    main()
参考文献:
[1] http://mxnet.io/api/scala/io.html?highlight=im2rec
阅读全文
0 0
- Mxnet图片分类(1)准备数据集
- MXNET笔记(二)准备数据
- Mxnet图片分类(2)训练模型
- Mxnet图片分类(3)fine-tune
- 深度学习遥感影像分类之数据集批量准备
- mxnet实战笔记(1) - 使用自己的图片数据训练CNN模型
- mxnet 图像分类
- mxnet大规模图像分类
- mxnet 使用自己的图片数据训练CNN模型
- MXNet学习1——数据模拟
- Mxnet图片分类(4)利用训练好的模型进行测试
- MXNet 中文教程:图像分类
- caffe 跑自己的图像分类任务(1) 之 准备数据
- MXNet数据加载
- MXNet数据生成
- Windows+caffe+libsvm对图片数据集的分类
- Windows+caffe+libsvm对图片数据集的分类
- imagenet 数据集准备
- [bzoj3931][CQOI2015]网络吞吐量 spfa+最大流
- Android 反编译
- FCC-----------Truncate a string
- 知识杂碎
- ireport3.7的scriptlet脚本使用
- Mxnet图片分类(1)准备数据集
- 七天学会ASP.NET MVC (三)——ASP.Net MVC 数据处理
- (懒人必备)Android开源数据库LitePal
- jquery控制input只能输入数字和两位小数(转)
- 获取cpu序列号
- IOC容器的初始化
- 博客暂时停更说明
- 排序!
- easyui tree 取消选择