风格转换简介
来源:互联网 发布:高级软件开发工程师 编辑:程序博客网 时间:2024/05/21 11:13
- 风格转换
- 优化问题
- 综述
- 损失函数
- 训练
- 例子
- 代码
- 网络转换
- 结构
- 训练
- 参考
风格转换,是把一张图片转化成同内容但包含某风格的新图片。本文将介绍如何让机器学习风格转换,包含两种方法:优化问题求解、转化网络求解。
风格转换
风格转换,就是根据现有的风格照片
本文将叙述两种风格转换的思路:
- 将风格转换变成优化问题的求解,构建
T,C 之间的损失Lc 以及T,S 之间的损失Ls ,同时增加图片平滑的损失Lv 。通过求解minT∑iLi 的优化问题求解。 - 不直接把目标图片
T 当做求解变量,而是构建一个transform network
把内容图片C 转化成目标图片T ,以类似1中的方法构建损失函数,通过求解transform network
的参数求解该问题。
优化问题
综述
首先,陈述问题:假设已知风格照片
下面,确定几个损失函数:
Ls :T 和S 风格上的距离Lc :T 和C 内容上的距离Lv :T 不平滑的度量
最后,便是求解优化问题:
损失函数
优化问题中
首先,简单介绍下VGG网络:它是一种固定的网络结构,其结构如下所示,一般采用D或E结构,通常叫VGG-16和VGG-19:
那么,为什么
训练后的VGG网络,每一层都对特征进行了抽象,越深得到的特征越具象。所以每一层的特征也就代表了图片不同粒度的抽象,可以根据特征的距离判断图片内容的相似程度。VGG的卷积层得到了feature map
,假设其大小是
假设在层feature map
是feature map
是
假设在层feature map
是feature map
是
其中,gram matrix
,feature map
上不同feature
的相互作用关系,用其来度量风格。
至于
训练
构建好损失函数
这里优化问题的求解方法采用L-BFGS
(一种伪牛顿法),这样做的目的是得到比gradient descent更快的收敛速度。
例子
本人是詹姆斯的铁杆球迷,对詹姆斯的照片采用不同风格转换后的效果图如下所示。需要说明的是:第二列第一张是未加平滑损失
代码
以下代码参考了Siraj Raval on YouTube
# Load libraryfrom __future__ import print_functionimport timefrom PIL import Imageimport numpy as npfrom keras import backendfrom keras.models import Modelfrom keras.applications.vgg16 import VGG16from scipy.optimize import fmin_l_bfgs_bfrom scipy.misc import imsave# Load and preprocess the content and style imagesheight = 512width = 512content_image_path = 'images/hugo.jpg'content_image = Image.open(content_image_path)content_image = content_image.resize((height, width))content_imagestyle_image_path = 'images/styles/wave.jpg'style_image = Image.open(style_image_path)style_image = style_image.resize((height, width))style_imagecontent_array = np.asarray(content_image, dtype='float32')content_array = np.expand_dims(content_array, axis=0)print(content_array.shape)style_array = np.asarray(style_image, dtype='float32')style_array = np.expand_dims(style_array, axis=0)print(style_array.shape)content_array[:, :, :, 0] -= 103.939content_array[:, :, :, 1] -= 116.779content_array[:, :, :, 2] -= 123.68content_array = content_array[:, :, :, ::-1]style_array[:, :, :, 0] -= 103.939style_array[:, :, :, 1] -= 116.779style_array[:, :, :, 2] -= 123.68style_array = style_array[:, :, :, ::-1]content_image = backend.variable(content_array)style_image = backend.variable(style_array)combination_image = backend.placeholder((1, height, width, 3))input_tensor = backend.concatenate([content_image, style_image, combination_image], axis=0)# Reuse a model pre-trained for image classification to define loss functionsmodel = VGG16(input_tensor=input_tensor, weights='imagenet', include_top=False)layers = dict([(layer.name, layer.output) for layer in model.layers])content_weight = 0.025style_weight = 5.0total_variation_weight = 1.0# Lossloss = backend.variable(0.)# The content lossdef content_loss(content, combination): return backend.sum(backend.square(combination - content))layer_features = layers['block2_conv2']content_image_features = layer_features[0, :, :, :]combination_features = layer_features[2, :, :, :]loss += content_weight * content_loss(content_image_features, combination_features)# The style lossdef gram_matrix(x): features = backend.batch_flatten(backend.permute_dimensions(x, (2, 0, 1))) gram = backend.dot(features, backend.transpose(features)) return gramdef style_loss(style, combination): S = gram_matrix(style) C = gram_matrix(combination) channels = 3 size = height * width return backend.sum(backend.square(S - C)) / (4. * (channels ** 2) * (size ** 2))feature_layers = ['block1_conv2', 'block2_conv2', 'block3_conv3', 'block4_conv3', 'block5_conv3']for layer_name in feature_layers: layer_features = layers[layer_name] style_features = layer_features[1, :, :, :] combination_features = layer_features[2, :, :, :] sl = style_loss(style_features, combination_features) loss += (style_weight / len(feature_layers)) * sl# The total variation lossdef total_variation_loss(x): a = backend.square(x[:, :height-1, :width-1, :] - x[:, 1:, :width-1, :]) b = backend.square(x[:, :height-1, :width-1, :] - x[:, :height-1, 1:, :]) return backend.sum(backend.pow(a + b, 1.25))loss += total_variation_weight * total_variation_loss(combination_image)# Define needed gradients and solve the optimisation problemgrads = backend.gradients(loss, combination_image)outputs = [loss]outputs += gradsf_outputs = backend.function([combination_image], outputs)def eval_loss_and_grads(x): x = x.reshape((1, height, width, 3)) outs = f_outputs([x]) loss_value = outs[0] grad_values = outs[1].flatten().astype('float64') return loss_value, grad_valuesclass Evaluator(object): def __init__(self): self.loss_value = None self.grads_values = None def loss(self, x): assert self.loss_value is None loss_value, grad_values = eval_loss_and_grads(x) self.loss_value = loss_value self.grad_values = grad_values return self.loss_value def grads(self, x): assert self.loss_value is not None grad_values = np.copy(self.grad_values) self.loss_value = None self.grad_values = None return grad_valuesevaluator = Evaluator()# Trainx = np.random.uniform(0, 255, (1, height, width, 3)) - 128.iterations = 10for i in range(iterations): print('Start of iteration', i) start_time = time.time() x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x.flatten(), fprime=evaluator.grads, maxfun=20) print('Current loss value:', min_val) end_time = time.time() print('Iteration %d completed in %ds' % (i, end_time - start_time))# Evaluationx = x.reshape((height, width, 3))x = x[:, :, ::-1]x[:, :, 0] += 103.939x[:, :, 1] += 116.779x[:, :, 2] += 123.68x = np.clip(x, 0, 255).astype('uint8')Image.fromarray(x)
网络转换
结构
将风格转换当成优化问题求解存在如下问题:
- 每来一张新图片,都需要重新求解优化问题。如果需要将大量图片转换成同一风格的话效率会很低
考虑能否构建一个transformer
,将图片transformer
的参数。训练完成得到transformer
后,当新的图片来到时,直接输入transformer
即可得到新的图片,大大提高了效率。
本节中的风格转换即采用上述构建transformer
的方法,利用预训练的VGG得到特征进而得到损失函数,通过调节transformer
的参数最小化损失函数。图示如下:
训练
损失函数的定义与优化问题部分相同,这里求解的优化问题是:
参考
- A Neural Algorithm of Artistic Style
- Perceptual Losses for Real-Time Style Transfer
and Super-Resolution
- 风格转换简介
- 窗口风格参数简介
- REST架构风格简介
- REST架构风格简介
- REST架构风格简介
- RESTful风格WebService简介
- web站点风格转换
- 实现风格转换页面
- 自动语义风格转换
- 注释转换(c风格转为c++风格)
- [编程风格要素] I 简介
- 注视转换 将C风格注释转换为C++风格
- C++风格的强制转换
- c++风格的类型转换
- C++风格的类型转换
- C++风格的强制转换
- 谷歌风格:强制性转换
- c++风格的类型转换
- 世先有良医然后才有良药
- shell_编程2(语法)
- 每天学一点Swift----面向对象上(三)
- compass 中的 页脚固定 sticky-footer
- UnpooledHeadByteBuf源码分析
- 风格转换简介
- 计算机网络之TCP实验(wireshark版)
- Linux内建命令
- 侯捷-《STL源码剖析》的一些可能的错误
- Webpack基础之输出
- C++格式化输出,C++输出格式控制
- codevs 1643 线段覆盖 3(贪心+快排)
- 探究线性表与链表
- C++中虚函数和纯虚函数的区别