图像扩充用于图像目标检测

来源:互联网 发布:淘宝网怎么投诉卖家 编辑:程序博客网 时间:2024/06/04 17:59
常用的图像扩充方式有:
水平翻转,裁剪,视角变换,jpeg压缩,尺度变换,颜色变换,旋转
当用于分类数据集时,这些变换方法可以全部被使用,然而考虑到目标检测标注框的变换,我们选择如下几种方式用于目标检测数据集扩充:
jpeg压缩,尺度变换,颜色变换
这里,我们介绍一个图像变换包:
http://lear.inrialpes.fr/people/paulin/projects/ITP/
这是项目主页,里面介绍了用于图像变换的基本方法,以及如何组合它们可以得到最好的效果,项目主页里同时带python程序。

里面的图像变换程序如下(在Windows下运行;为了用于目标检测,做了一些修改):

import ast
import os
import pdb
import sys

import numpy
from PIL import Image, ImageChops, ImageOps, ImageDraw

# Parameters used for the CVPR paper.
NCROPS = 10
NHOMO = 8
JPG = [70, 50, 30]
ROTS = [3, 6, 9, 12, 15]
SCALES = [1.5**0.5, 1.5, 1.5**1.5, 1.5**2, 1.5**2.5]

# Parameters computed on the ILSVRC10 dataset.
# NOTE(review): presumably the eigenvalues (lcolor) and eigenvectors (pcolor)
# of the RGB covariance used for PCA colour jitter -- confirm against the paper.
lcolor = [381688.61379382, 4881.28307136, 2316.10313483]
pcolor = [[-0.57848371, -0.7915924, 0.19681989],
          [-0.5795621, 0.22908373, -0.78206676],
          [-0.57398987, 0.56648223, 0.59129816]]

# Pre-generated gaussian values, one row per colour transformation index.
alphas = [[0.004894, 0.153527, -0.012182],
          [-0.058978, 0.114067, -0.061488],
          [0.002428, -0.003576, -0.125031]]


def gen_colorimetry(i):
    """Return the (dr, dg, db) colour offsets for colour transformation *i*.

    Each offset is ``sum_k alpha[k] * sqrt(lcolor[k]) * pcolor[c][k]`` for
    channel ``c``.  Indices covered by ``alphas`` use the pre-generated
    coefficients; larger indices draw three coefficients from a generator
    seeded with ``i * 3`` so the result is reproducible.
    """
    if i < len(alphas):
        alpha = alphas[i]
    else:
        numpy.random.seed(i * 3)
        # The original called randn(3, 0, 0.01), which is not a valid call.
        # NOTE(review): interpreted as 3 samples from N(mean=0, std=0.01);
        # confirm the intended standard deviation.
        alpha = numpy.random.normal(0, 0.01, 3)

    sigma = [numpy.sqrt(value) for value in lcolor]
    return tuple(
        sum(alpha[k] * sigma[k] * pcolor[channel][k] for k in range(3))
        for channel in range(3)
    )


def gen_crop(i, w, h):
    """Return a pseudo-random crop box (x0, y0, x1, y1) for a w-by-h image.

    The generator is seeded with ``4 * i``, so the same index always gives
    the same box.  Each edge is moved inwards by a random amount below a
    quarter of the corresponding dimension.
    """
    numpy.random.seed(4 * i)
    # Floor division keeps the integer semantics the Python 2 original had.
    x0 = numpy.random.random() * (w // 4)
    y0 = numpy.random.random() * (h // 4)
    x1 = w - numpy.random.random() * (w // 4)
    y1 = h - numpy.random.random() * (h // 4)
    return (int(x0), int(y0), int(x1), int(y1))


def gen_homo(i, w, h):
    """Return the 8-tuple source quadrilateral for homography index *i*.

    The tuple is passed to ``Image.transform(..., Image.QUAD, ...)`` by
    ``gen_trans``.  Valid indices are 0-7.

    Raises:
        AssertionError: if *i* is not one of the eight known indices.
    """
    quads = (
        (0, 0, int(0.125 * w), h, int(0.875 * w), h, w, 0),
        (0, 0, int(0.25 * w), h, int(0.75 * w), h, w, 0),
        (0, int(0.125 * h), 0, int(0.875 * h), w, h, w, 0),
        (0, int(0.25 * h), 0, int(0.75 * h), w, h, w, 0),
        (int(0.125 * w), 0, 0, h, w, h, int(0.875 * w), 0),
        (int(0.25 * w), 0, 0, h, w, h, int(0.75 * w), 0),
        (0, 0, 0, h, w, int(0.875 * h), w, int(0.125 * h)),
        (0, 0, 0, h, w, int(0.75 * h), w, int(0.25 * h)),
    )
    if not 0 <= i < len(quads):
        # Raised explicitly so the check survives ``python -O``.
        raise AssertionError("Unknown homography index: %r" % (i,))
    return quads[i]


def rot(image, angle, fname):
    """Rotate *image* by *angle* degrees in place and return it.

    The rotated copy is pasted back onto the original through a mask made by
    rotating an all-white image the same way, so the areas uncovered by the
    rotation keep the original pixels.  *fname* is only used in the error
    message written to stderr when the paste fails.
    """
    white = Image.new('L', image.size, "white")
    mask = white.rotate(angle, Image.NEAREST, expand=0)
    rotated = image.rotate(angle, Image.BILINEAR, expand=0)
    try:
        image.paste(rotated, mask)
    except ValueError:
        sys.stderr.write('error: image do not match ' + fname + '\n')
    return image


def gen_corner(n, w, h):
    """Return crop box *n* (0 = centre, 1-4 = corners) for a w-by-h image.

    The box keeps 227/256 of each dimension.

    Raises:
        AssertionError: if *n* is not in 0-4.
    """
    x0 = 0
    x1 = w
    y0 = 0
    y1 = h

    rat = 256 - 227
    if n == 0:  # center
        x0 = (rat * w) / (2 * 256.0)
        y0 = (rat * h) / (2 * 256.0)
        x1 = w - (rat * w) / (2 * 256.0)
        y1 = h - (rat * h) / (2 * 256.0)
    elif n == 1:
        x0 = (rat * w) / 256.0
        y0 = (rat * h) / 256.0
    elif n == 2:
        x1 = w - (rat * w) / 256.0
        y0 = (rat * h) / 256.0
    elif n == 3:
        x1 = w - (rat * w) / 256.0
        y1 = h - (rat * h) / 256.0
    elif n == 4:
        x0 = (rat * w) / 256.0
        y1 = h - (rat * h) / 256.0
    else:
        raise AssertionError("Unknown corner index: %r" % (n,))
    return (int(x0), int(y0), int(x1), int(y1))


def gen_trans(imgfile, trans, outfile):
    """Apply transformation *trans* to *imgfile* and save it to *outfile*.

    This is the main function to call.  *trans* is a transformation name
    ("plain", "flip", "cropN", "homoN", "jpgQ", "scaleN", "colorN", "rotA",
    "rot-A", "cornerN"); several names may be chained with ``*``, in which
    case each step reads the output of the previous one.

    Raises:
        AssertionError: if a transformation name is not recognised.
    """
    for step in trans.split('*'):
        image = Image.open(imgfile)
        w, h = image.size
        if step == "plain":
            image.save(outfile, "JPEG", quality=100)
        elif step == "flip":
            ImageOps.mirror(image).save(outfile, "JPEG", quality=100)
        elif step.startswith("crop"):
            box = gen_crop(int(step[4:]), w, h)
            image.crop(box).save(outfile, "JPEG", quality=100)
        elif step.startswith("homo"):
            quad = gen_homo(int(step[4:]), w, h)
            image.transform((w, h), Image.QUAD, quad,
                            Image.BILINEAR).save(outfile, "JPEG", quality=100)
        elif step.startswith("jpg"):
            # The number after "jpg" is the JPEG quality to re-encode with.
            image.save(outfile, "JPEG", quality=int(step[3:]))
        elif step.startswith("scale"):
            scale = SCALES[int(step[5:])]
            image.resize((int(w / scale), int(h / scale)),
                         Image.BILINEAR).save(outfile, "JPEG", quality=100)
        elif step.startswith('color'):
            # Parse the whole suffix so indices above 9 are not truncated.
            (dr, dg, db) = gen_colorimetry(int(step[5:]))
            # One 256-entry lookup table per channel, shifted by the offset.
            table = numpy.tile(numpy.arange(256), 3)
            table[:256] += int(dr)
            table[256:512] += int(dg)
            table[512:] += int(db)
            # Keep the table inside the valid 8-bit range and hand PIL plain
            # Python ints.
            table = numpy.clip(table, 0, 255).tolist()
            image.convert("RGB").point(table).save(outfile, "JPEG",
                                                   quality=100)
        elif step.startswith('rot-'):
            # Must be tested before 'rot', which is a prefix of 'rot-'.
            for _ in range(int(step[4:])):
                image = rot(image, -1, outfile)
            image.save(outfile, "JPEG", quality=100)
        elif step.startswith('rot'):
            # The rotation is applied one degree at a time.
            for _ in range(int(step[3:])):
                image = rot(image, 1, outfile)
            image.save(outfile, "JPEG", quality=100)
        elif step.startswith('corner'):
            box = gen_corner(int(step[6:]), w, h)
            image.crop(box).save(outfile, "JPEG", quality=100)
        else:
            raise AssertionError("Unrecognized transformation: " + step)
        imgfile = outfile  # in case we iterate


def get_all_trans():
    """Return the default transformation names.

    Only transformations that keep bounding boxes trivially recoverable are
    listed: the unchanged image, JPEG re-compression, rescaling and colour
    shifts.  Flip, crop, homography and rotation are deliberately left out.
    """
    transformations = (["plain"]
                       + ["jpg%d" % i for i in JPG]
                       + ["scale0", "scale1", "scale2", "scale3", "scale4"]
                       + ["color%d" % i for i in range(3)])
    return transformations


def get_deep_trans():
    """Return the transformations used at test time in deep architectures."""
    return ['corner0', 'corner1', 'corner2', 'corner3', 'corner4',
            'corner0*flip', 'corner1*flip', 'corner2*flip', 'corner3*flip',
            'corner4*flip']


def main():
    """Command line entry point: ``script INPUT_DIR OUTPUT_DIR [TRANS]``.

    TRANS is either one transformation name or a Python list literal of
    names; when omitted, ``get_all_trans()`` is used.
    """
    if len(sys.argv) < 3:
        sys.exit("usage: %s INPUT_DIR OUTPUT_DIR [TRANS]" % sys.argv[0])
    inputpath = sys.argv[1]
    names = [name for name in os.listdir(inputpath)
             if os.path.isfile(os.path.join(inputpath, name))]
    outpath = sys.argv[2]
    if len(sys.argv) >= 4:
        trans = sys.argv[3]
        if not trans.startswith("["):
            trans = [trans]
        else:
            # literal_eval parses the list without executing arbitrary code,
            # unlike eval on a command-line argument.
            trans = ast.literal_eval(trans)
    else:
        trans = get_all_trans()
    print("Generating transformations and storing in %s" % (outpath))
    for k in names:
        img_input = os.path.join(inputpath, k)
        stem = os.path.splitext(k)[0]
        for t in trans:
            gen_trans(img_input, t,
                      os.path.join(outpath, '%s_%s.jpg' % (stem, t)))
    print("Finished. Transformations generated: %s" % (" ".join(trans)))


if __name__ == "__main__":
    main()

这是变换前的图片:1_7_17.jpg

变换后的图片如下:
1_7_17_color0.jpg

1_7_17_color1.jpg

1_7_17_color2.jpg

1_7_17_jpg30.jpg

1_7_17_jpg50.jpg

1_7_17_jpg70.jpg

1_7_17_scale0.jpg

1_7_17_scale1.jpg

1_7_17_scale2.jpg

1_7_17_scale3.jpg

1_7_17_scale4.jpg

用于目标检测时:XML标注文件也需要做相应修改,主要是针对尺度变换:
修改前的xml文件如下(1_7_17.jpg):
<annotation><folder>spl</folder><filename>1_7_17.jpg</filename><source><database>The spl Database</database><annotation>The spl Database</annotation><image>spl</image><flickrid>0</flickrid></source><owner><flickrid>spl</flickrid><name>xiaovv</name></owner><size><width>800</width><height>800</height><depth>3</depth></size><segmented>0</segmented><object><name>aeroplane</name><pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult><bndbox><xmin>151</xmin><ymin>357</ymin><xmax>212</xmax><ymax>399</ymax></bndbox></object><object><name>aeroplane</name><pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult><bndbox><xmin>134</xmin><ymin>593</ymin><xmax>193</xmax><ymax>654</ymax></bndbox></object></annotation>
修改后的xml文件如下(1_7_17_scale4.jpg):
<?xml version="1.0" encoding="utf-8"?><annotation><folder>spl</folder><filename>1_7_17_scale4.jpg</filename><source><database>The spl Database</database><annotation>The spl Database</annotation><image>spl</image><flickrid>0</flickrid></source><owner><flickrid>spl</flickrid><name>xiaovv</name></owner><size><width>290</width><height>290</height><depth>3</depth></size><segmented>0</segmented><object><name>aeroplane</name><pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult><bndbox><xmin>54</xmin><ymin>129</ymin><xmax>76</xmax><ymax>144</ymax></bndbox></object><object><name>aeroplane</name><pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult><bndbox><xmin>48</xmin><ymin>215</ymin><xmax>70</xmax><ymax>237</ymax></bndbox></object></annotation>
修改xml文件的程序如下:
# -*- coding=utf-8 -*-
import ast
import os
import sys
import shutil
from xml.dom.minidom import Document
from xml.etree.ElementTree import ElementTree, Element
import xml.dom.minidom

JPG = [70, 50, 30]
SCALES = [1.5**0.5, 1.5, 1.5**1.5, 1.5**2, 1.5**2.5]

# Bounding-box coordinate tags that must be rescaled with the image.
_BOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')


def _scale_nodes(nodes, scale):
    """Divide the integer text of every node in *nodes* by *scale*."""
    for node in nodes:
        value = int(node.firstChild.data)
        node.firstChild.data = str(int(value / scale))


def gen_xml(xml_input, trans, outfile):
    """Write the annotation file matching a transformed image.

    Reads the annotation *xml_input*, sets its ``<filename>`` to the stem of
    *outfile* plus ``.jpg`` and, for every ``scaleN`` step in *trans*,
    divides the image size and all bounding-box coordinates by
    ``SCALES[N]``.  ``plain``, ``jpg*`` and ``color*`` steps leave the
    geometry untouched.  Steps chained with ``*`` are applied in order to
    the same document, then the result is written once, UTF-8 encoded.

    Raises:
        AssertionError: if a transformation name is not recognised.
    """
    dom = xml.dom.minidom.parse(xml_input)
    root = dom.documentElement

    # Accept both Windows and POSIX separators when extracting the file name.
    base = outfile.replace('\\', '/').rsplit('/', 1)[-1]
    filename = root.getElementsByTagName('filename')[0]
    filename.firstChild.data = os.path.splitext(base)[0] + '.jpg'

    for step in trans.split('*'):
        if step == "plain" or step.startswith(("jpg", "color")):
            # These transformations only need the new file name.
            continue
        if step.startswith("scale"):
            # Scaling changes the image size and every box coordinate.
            scale = SCALES[int(step[5:])]
            # Only the first <height>/<width> describes the image size.
            _scale_nodes(root.getElementsByTagName('height')[:1], scale)
            _scale_nodes(root.getElementsByTagName('width')[:1], scale)
            for tag in _BOX_TAGS:
                _scale_nodes(root.getElementsByTagName(tag), scale)
        else:
            raise AssertionError("Unrecognized transformation: " + step)

    # Write bytes so the content really is UTF-8, as the XML header declares,
    # and close the file deterministically.
    with open(outfile, 'wb') as f:
        f.write(dom.toxml(encoding='utf-8'))


def get_all_trans():
    """Return the names of all transformations to generate."""
    transformations = (["plain"]
                       + ["jpg%d" % i for i in JPG]
                       + ["scale0", "scale1", "scale2", "scale3", "scale4"]
                       + ["color%d" % i for i in range(3)])
    return transformations


def main():
    """Command line entry point: ``script INPUT_DIR OUTPUT_DIR [TRANS]``.

    TRANS is either one transformation name or a Python list literal of
    names; when omitted, ``get_all_trans()`` is used.
    """
    if len(sys.argv) < 3:
        sys.exit("usage: %s INPUT_DIR OUTPUT_DIR [TRANS]" % sys.argv[0])
    inputpath = sys.argv[1]
    names = [name for name in os.listdir(inputpath)
             if os.path.isfile(os.path.join(inputpath, name))]
    outpath = sys.argv[2]
    if len(sys.argv) >= 4:
        trans = sys.argv[3]
        if not trans.startswith("["):
            trans = [trans]
        else:
            # literal_eval parses the list without executing arbitrary code,
            # unlike eval on a command-line argument.
            trans = ast.literal_eval(trans)
    else:
        trans = get_all_trans()
    print("Generating transformations and storing in %s" % (outpath))
    for k in names:
        xml_input = os.path.join(inputpath, k)
        stem = os.path.splitext(k)[0]
        for t in trans:
            gen_xml(xml_input, t,
                    os.path.join(outpath, '%s_%s.xml' % (stem, t)))


if __name__ == "__main__":
    main()