深度学习-Overfitting-数据增强Data Augmentation
来源:互联网 发布:超时空战士优化排期 编辑:程序博客网 时间:2024/05/05 13:15
深度学习-Overfitting解决方法:
人工增加训练集的大小. 通过平移, 翻转, 加噪声等方法从已有数据中创造出一批”新”的数据.也就是Data Augmentation
Regularization. 数据量比较小会导致模型过拟合, 使得训练误差很小而测试误差特别大. 通过在Loss Function 后面加上正则项可以抑制过拟合的产生. 缺点是引入了一个需要手动调整的hyper-parameter. 详见 https://www.wikiwand.com/en/Regularization_(mathematics)
Dropout. 这也是一种正则化手段. 不过跟以上不同的是它通过随机将部分神经元的输出置零来实现. 详见 http://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf
Unsupervised Pre-training. 用Auto-Encoder或者RBM的卷积形式一层一层地做无监督预训练, 最后加上分类层做有监督的Fine-Tuning. 参考 http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.207.1102&rep=rep1&type=pdf
海康威视研究院ImageNet2016竞赛经验分享,其中包括数据增广方法Label shuffle::https://zhuanlan.zhihu.com/p/23249000
在深度学习中,为了避免出现过拟合(Overfitting),通常我们需要输入充足的数据量.为了得到更加充足的数据,我们通常需要对原有的图像数据进行几何变换,改变图像像素的位置并保证特征不变.
数据增强变换(Data Augmentation Transformation)
1 随机裁剪random crop;Random Crop:采用随机图像差值方式,对图像进行裁剪、缩放;包括Scale Jittering方法(VGG及ResNet模型使用)或者尺度和长宽比增强变换;2 旋转 | 反射变换(Rotation/reflection)。镜像变换(random mirror): 随机旋转图像一定角度; 改变图像内容的朝向;3 翻转变换(flip): 沿着水平或者垂直方向翻转图像;4 缩放变换(zoom),random resize (scale augmentation): 按照一定的比例放大或者缩小图像;5 平移变换(shift): 在图像平面上对图像以一定方式进行平移;6 可以采用随机或人为定义的方式指定平移范围和平移步长, 沿水平或竖直方向进行平移. 改变图像内容的位置;7 尺度变换(scale): 对图像按照指定的尺度因子, 进行放大或缩小; 或者参照SIFT特征提取思想, 利用指定的尺度因子对图像滤波构造尺度空间. 改变图像内容的大小或模糊程度;8 对比度变换(contrast): 在图像的HSV颜色空间,改变饱和度S和V亮度分量,保持色调H不变. 对每个像素的S和V分量进行指数运算(指数因子在0.25到4之间), 增加光照变化;9 噪声扰动(noise): 对图像的每个像素RGB进行随机扰动, 常用的噪声模式是椒盐噪声和高斯噪声;10 颜色变换(color): 在训练集像素值的RGB颜色空间进行PCA, 得到RGB空间的3个主方向向量;11 RGB转到HSV空间,然后改变SV而保持H不变的方法,就是改变光照而不改变色调。 对每个像素的S和V分量进行指数运算;12 Color Jittering:对颜色的数据增强:图像亮度、饱和度、对比度变化(此处对色彩抖动的理解不知是否得当);13 PCA Jittering:首先按照RGB三个颜色通道计算均值和标准差,再在整个训练集上计算协方差矩阵,进行特征分解,得到特征向量和特征值,用来做PCA Jittering;14
python脚本:
# -*- coding:utf-8 -*-"""数据增强 1. 翻转变换 flip 2. 随机修剪 random crop 3. 色彩抖动 color jittering 4. 平移变换 shift 5. 尺度变换 scale 6. 对比度变换 contrast 7. 噪声扰动 noise 8. 旋转变换/反射变换 Rotation/reflection"""from PIL import Image, ImageEnhance, ImageOps, ImageFileimport numpy as npimport randomimport threading, os, timeimport logginglogger = logging.getLogger(__name__)ImageFile.LOAD_TRUNCATED_IMAGES = Trueclass DataAugmentation: """ 包含数据增强的八种方式 """ def __init__(self): pass @staticmethod def openImage(image): return Image.open(image, mode="r") @staticmethod def randomRotation(image, mode=Image.BICUBIC): """ 对图像进行随机任意角度(0~360度)旋转 :param mode 邻近插值,双线性插值,双三次B样条插值(default) :param image PIL的图像image :return: 旋转转之后的图像 """ random_angle = np.random.randint(1, 360) return image.rotate(random_angle, mode) @staticmethod def randomCrop(image): """ 对图像随意剪切,考虑到图像大小范围(68,68),使用一个一个大于(36*36)的窗口进行截图 :param image: PIL的图像image :return: 剪切之后的图像 """ image_width = image.size[0] image_height = image.size[1] crop_win_size = np.random.randint(40, 68) random_region = ( (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1, (image_height + crop_win_size) >> 1) return image.crop(random_region) @staticmethod def randomColor(image): """ 对图像进行颜色抖动 :param image: PIL的图像image :return: 有颜色色差的图像image """ random_factor = np.random.randint(0, 31) / 10. # 随机因子 color_image = ImageEnhance.Color(image).enhance(random_factor) # 调整图像的饱和度 random_factor = np.random.randint(10, 21) / 10. # 随机因子 brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor) # 调整图像的亮度 random_factor = np.random.randint(10, 21) / 10. # 随机因1子 contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor) # 调整图像对比度 random_factor = np.random.randint(0, 31) / 10. # 随机因子 return ImageEnhance.Sharpness(contrast_image).enhance(random_factor) # 调整图像锐度 @staticmethod def randomGaussian(image, mean=0.2, sigma=0.3): """ 对图像进行高斯噪声处理 :param image: :return: """ def gaussianNoisy(im, mean=0.2, sigma=0.3): """ 对图像做高斯噪音处理 :param im: 单通道图像 :param mean: 偏移量 :param sigma: 标准差 :return: """ for _i in range(len(im)): im[_i] += random.gauss(mean, sigma) return im # 将图像转化成数组 img = np.asarray(image) img.flags.writeable = True # 将数组改为读写模式 width, height = img.shape[:2] img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma) img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma) img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma) img[:, :, 0] = img_r.reshape([width, height]) img[:, :, 1] = img_g.reshape([width, height]) img[:, :, 2] = img_b.reshape([width, height]) return Image.fromarray(np.uint8(img)) @staticmethod def saveImage(image, path): image.save(path)def makeDir(path): try: if not os.path.exists(path): if not os.path.isfile(path): # os.mkdir(path) os.makedirs(path) return 0 else: return 1 except Exception, e: print str(e) return -2def imageOps(func_name, image, des_path, file_name, times=5): funcMap = {"randomRotation": DataAugmentation.randomRotation, "randomCrop": DataAugmentation.randomCrop, "randomColor": DataAugmentation.randomColor, "randomGaussian": DataAugmentation.randomGaussian } if funcMap.get(func_name) is None: logger.error("%s is not exist", func_name) return -1 for _i in range(0, times, 1): new_image = funcMap[func_name](image) DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name))opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"}def threadOPS(path, new_path): """ 多线程处理事务 :param src_path: 资源文件 :param des_path: 目的地文件 :return: """ if os.path.isdir(path): img_names = os.listdir(path) else: img_names = [path] for img_name in img_names: print img_name tmp_img_name = os.path.join(path, img_name) if os.path.isdir(tmp_img_name): if makeDir(os.path.join(new_path, img_name)) != -1: threadOPS(tmp_img_name, os.path.join(new_path, img_name)) else: print 'create new dir failure' return -1 # os.removedirs(tmp_img_name) elif tmp_img_name.split('.')[1] != "DS_Store": # 读取文件并进行操作 image = DataAugmentation.openImage(tmp_img_name) threadImage = [0] * 5 _index = 0 for ops_name in opsList: threadImage[_index] = threading.Thread(target=imageOps, args=(ops_name, image, new_path, img_name,)) threadImage[_index].start() _index += 1 time.sleep(0.2)if __name__ == '__main__': threadOPS("/home/pic-image/train/12306train", "/home/pic-image/train/12306train3")
旋转图像并修改对应的xml文件
http://blog.csdn.net/u014540717/article/details/53301195
import cv2import mathimport numpy as npimport os# pdb仅仅用于调试,不用管它import pdb#旋转图像的函数def rotate_image(src, angle, scale=1.): w = src.shape[1] h = src.shape[0] # 角度变弧度 rangle = np.deg2rad(angle) # angle in radians # now calculate new image width and height nw = (abs(np.sin(rangle)*h) + abs(np.cos(rangle)*w))*scale nh = (abs(np.cos(rangle)*h) + abs(np.sin(rangle)*w))*scale # ask OpenCV for the rotation matrix rot_mat = cv2.getRotationMatrix2D((nw*0.5, nh*0.5), angle, scale) # calculate the move from the old center to the new center combined # with the rotation rot_move = np.dot(rot_mat, np.array([(nw-w)*0.5, (nh-h)*0.5,0])) # the move only affects the translation, so update the translation # part of the transform rot_mat[0,2] += rot_move[0] rot_mat[1,2] += rot_move[1] # 仿射变换 return cv2.warpAffine(src, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)//对应修改xml文件def rotate_xml(src, xmin, ymin, xmax, ymax, angle, scale=1.): w = src.shape[1] h = src.shape[0] rangle = np.deg2rad(angle) # angle in radians # now calculate new image width and height # 获取旋转后图像的长和宽 nw = (abs(np.sin(rangle)*h) + abs(np.cos(rangle)*w))*scale nh = (abs(np.cos(rangle)*h) + abs(np.sin(rangle)*w))*scale # ask OpenCV for the rotation matrix rot_mat = cv2.getRotationMatrix2D((nw*0.5, nh*0.5), angle, scale) # calculate the move from the old center to the new center combined # with the rotation rot_move = np.dot(rot_mat, np.array([(nw-w)*0.5, (nh-h)*0.5,0])) # the move only affects the translation, so update the translation # part of the transform rot_mat[0,2] += rot_move[0] rot_mat[1,2] += rot_move[1] # rot_mat是最终的旋转矩阵 # 获取原始矩形的四个中点,然后将这四个点转换到旋转后的坐标系下 point1 = np.dot(rot_mat, np.array([(xmin+xmax)/2, ymin, 1])) point2 = np.dot(rot_mat, np.array([xmax, (ymin+ymax)/2, 1])) point3 = np.dot(rot_mat, np.array([(xmin+xmax)/2, ymax, 1])) point4 = np.dot(rot_mat, np.array([xmin, (ymin+ymax)/2, 1])) # 合并np.array concat = np.vstack((point1, point2, point3, point4)) # 改变array类型 concat = concat.astype(np.int32) print concat rx, ry, rw, rh = cv2.boundingRect(concat) return rx, ry, rw, rh# 使图像旋转60,90,120,150,210,240,300度for angle in (60, 90, 120, 150, 210, 240, 300): # 指向图片所在的文件夹 for i in os.listdir("/home/username/image"): # 分离文件名与后缀 a, b = os.path.splitext(i) # 如果后缀名是“.jpg”就旋转图像 if b == ".jpg": img_path = os.path.join("/home/username/image", i) img = cv2.imread(img_path) rotated_img = rotate_image(img, angle) # 写入图像 cv2.imwrite("/home/yourname/rotate/" + a + "_" + str(angle) +"d.jpg", rotated_img) print "log: [%sd] %s is processed." % (angle, i) else: xml_path = os.path.join("/home/username/xml", i) img_path = "/home/guoyana/varied_pose/" + a + ".jpg" src = cv2.imread(img_path) tree = ET.parse(xml_path) root = tree.getroot() for box in root.iter('bndbox'): xmin = float(box.find('xmin').text) ymin = float(box.find('ymin').text) xmax = float(box.find('xmax').text) ymax = float(box.find('ymax').text) x, y, w, h = rotate_xml(src, xmin, ymin, xmax, ymax, angle) # 改变xml中的人脸坐标值 box.find('xmin').text = str(x) box.find('ymin').text = str(y) box.find('ymax').text = str(x+w) box.find('ymax').text = str(y+h) box.set('updated', 'yes') # 写入新的xml tree.write("/home/username/xml/" + a + "_" + str(angle) +".xml") print "[%s] %s is processed." % (angle, i)
- 深度学习-Overfitting-数据增强Data Augmentation
- DATA Augmentation 数据增强
- 数据增强 data augmentation
- 数据增强(data Augmentation)
- CNN Data Augmentation(数据增强)-旋转
- data augmentation 数据增强方法总结
- data augmentation 数据增强方法总结
- 【转】data augmentation 数据增强方法总结
- 深度学习中的Data Augmentation方法
- 02-深度学习中的Data Augmentation方法
- 深度学习样本生成data augmentation
- 图片的数据增强(Data Augmentation)方法
- 深度学习中的Data Augmentation和代码实现
- 深度学习中的Data Augmentation方法(转)基于keras
- 深度学习中的Data Augmentation方法(转)
- 深度学习避免过拟合的方法---Data Augmentation
- 深度学习--通过正则化regularization防止overfitting
- 深度学习--数据增强
- YOLOv2训练自己的数据集
- Python----Pyc, 模块
- 【Java团队用OpenResty】4、Redis的高可用性
- windows10 搭建 NTP 时间服务器
- 如何将营业执照完整保留为Word格式?
- 深度学习-Overfitting-数据增强Data Augmentation
- 从 SQLite 3导出数据
- [PAT]1003. 我要通过! (Python)
- Cloneable,Comparable,Comparator接口
- Git分支指针移动到不同的提交
- FL2440开发板介绍及其烧录
- PAT考试乙级1021(C语言实现)
- springmvc项目(一)
- android 兼容性测试 CTS 测试过程(实践测试验证通过)