利用keras进行样本扩充

来源:互联网 发布:黑客网络hacknet攻略 编辑:程序博客网 时间:2024/05/21 01:29

1.首先安装tensorflow和keras(建议通过conda进行安装),网上有好多教程,以下面的教程为例

http://www.linuxidc.com/Linux/2016-07/133214.htm

2.将要扩充的样本放在data目录下的train文件夹下,并在data文件夹下新建一个preview的文件夹用来存放扩充后的样本,具体目录如下:

 test/data/preview <and> train

train目录下为需要扩充的样本文件夹:如bus,flower,horse等等

3.其具体代码如下,只有30多行

empty#-*- coding:utf-8 -*-import osfrom keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_imgdatagen = ImageDataGenerator(        rotation_range=40,        width_shift_range=0.2,        height_shift_range=0.2,        shear_range=0.2,        zoom_range=0.2,        horizontal_flip=True,        fill_mode='nearest')def expand(lable):j = 1if not os.path.isdir('/home/wyq/test/data/preview/'+lable): #判断preview目录下是否存在××该文件夹,os.mkdir('/home/wyq/test/data/preview/'+lable) #若不存在则创建一个文件夹来保存扩充后的样本for file_name in os.listdir('/home/wyq/test/data/train/'+lable): #要扩充的图片所在目录img = load_img('/home/wyq/test/data/train/'+lable+'/'+file_name) #this is a PIL imagex = img_to_array(img)  # this is a Numpy array with shape (3,150,150)x = x.reshape((1,) + x.shape)  # this is a Numpy array with shape (1,3,150,150)i = 1for batch in datagen.flow(x, batch_size=1, save_to_dir='/home/wyq/test/data/preview/'+lable, save_prefix=lable, save_format='jpg'):#设置扩充后的样本保存位置及属性i += 1if i > 10:  #每张图片扩充10张break  # otherwise the generator would loop indefinitelyj +=1if j>100: #每个文件夹中有100张图片,故遍历100次breakexpand("bus")# bus为车的图片的文件夹名称expand("dinosaur")#dinosaur为恐龙图片的文件夹名称expand("elephant")#同上expand("flower")#同上expand("horse")#同上