Deep Bilateral Learning for Real-Time Image Enhancement

来源:互联网 发布:sqlserver rds客户端 编辑:程序博客网 时间:2024/05/18 13:05

模型结构为:

这里写图片描述

low resolutioion 图像特征提取

1 low-lever features

如上图所示,利用nS个卷积(4层,卷积核为3×3,stride=2),从low-resolution图像中提取低层特征Si:,公式如下:
这里写图片描述

式中,I=1,...,nS为每个卷积层的索引,c,c为为卷积层的channels的索引.w为卷积核权重矩阵.bi为bias.激活函数σ采用ReLU,卷积时采用zero-padding.

2 Local features path

低层特征Si输入一个nL=2层卷积层得到局部特征Li.nS+nL对于语义特征的获取很关键,如果要获得一个更高空间的分辨率,可以通过减小nS,增大nL实现.

3 Global features path

全局特征层有2个卷积层,stride=2,之后接3个全连接层组成,层数为nG=5.全局特征效果:
这里写图片描述

4 Fusion and linear prediction

使用一个pointtwise仿射变换,加一个ReLU激活函数,来融合全局和局部特征:
这里写图片描述

这样得到了一个16×16×64的特征矩阵,将其输入1×1的卷积层得到16×16,output channels=96:
这里写图片描述

参数设置如下:
这里写图片描述

Image features as a bilateral grid

由low resolution 图像湖提取特征为16×16×96的feature map.可以等价与grid深度为d的多通道 bilateral grid:
这里写图片描述
取d=9,这样就等价于有一个16×16×8的 bilateral grid,每个grid cell包含12个,每个还有一个3×4的仿射颜色变换矩阵.

Upsampling with a trainable slicing layer

Guidance map auxiliary network

定义g为一个pointwise非线性变换,

这里写图片描述
这里写图片描述
式中,MTc3×3的颜色变换矩阵,MTc,a,t,b,b为网络要学习的参数.

Assembling the final output

最后的输入Oc由full-resolution features和sliced feature map的仿射变换得到:

这里写图片描述

模型inference代码为:

def inference(cls, lowres_input, fullres_input, params,              is_training=False):  with tf.variable_scope('coefficients'):    bilateral_coeffs = cls._coefficients(lowres_input, params, is_training)    tf.add_to_collection('bilateral_coefficients', bilateral_coeffs)  with tf.variable_scope('guide'):    guide = cls._guide(fullres_input, params, is_training)    tf.add_to_collection('guide', guide)  with tf.variable_scope('output'):    output = cls._output(        fullres_input, guide, bilateral_coeffs)    tf.add_to_collection('output', output)  return output

每个模块代码分析

1 low-lever features Si

输入为low-res input,网络结构为n个卷积层,卷积核为3×3,stride=2,代码如下:

with tf.variable_scope('splat'):  n_ds_layers = int(np.log2(params['net_input_size']/spatial_bin))  current_layer = input_tensor  for i in range(n_ds_layers):    if i > 0:  # don't normalize first layer      use_bn = params['batch_norm']    else:      use_bn = False    current_layer = conv(current_layer, cm*(2**i)*gd, 3, stride=2,                         batch_norm=use_bn, is_training=is_training,                         scope='conv{}'.format(i+1))  splat_features = current_layer

2 local features Li

用于提取图像的局部特征,网络结构为l2个卷积层,卷积核为3×3,stride=1,第一个卷积层采用batchnorm.

with tf.variable_scope('local'):  current_layer = splat_features  current_layer = conv(current_layer, 8*cm*gd, 3,                        batch_norm=params['batch_norm'],                        is_training=is_training,                       scope='conv1')  # don't normalize before fusion  current_layer = conv(current_layer, 8*cm*gd, 3, activation_fn=None,                       use_bias=False, scope='conv2')  grid_features = current_layer

3 global features Gi

用于提取全局特征,网络结构为两个卷积层,卷积核为3×3,stride=2,卷积层之后是三个全连接层,代码如下:

with tf.variable_scope('global'):  n_global_layers = int(np.log2(spatial_bin/4))  # 4x4 at the coarsest lvl  current_layer = splat_features  for i in range(2):    current_layer = conv(current_layer, 8*cm*gd, 3, stride=2,        batch_norm=params['batch_norm'], is_training=is_training,        scope="conv{}".format(i+1))  _, lh, lw, lc = current_layer.get_shape().as_list()  current_layer = tf.reshape(current_layer, [bs, lh*lw*lc])  current_layer = fc(current_layer, 32*cm*gd,                      batch_norm=params['batch_norm'], is_training=is_training,                     scope="fc1")  current_layer = fc(current_layer, 16*cm*gd,                      batch_norm=params['batch_norm'], is_training=is_training,                     scope="fc2")  # don't normalize before fusion  current_layer = fc(current_layer, 8*cm*gd, activation_fn=None, scope="fc3")  global_features = current_layer

将local feature 和global feture相加,得到fusion feature:

with tf.name_scope('fusion'):  fusion_grid = grid_features  fusion_global = tf.reshape(global_features, [bs, 1, 1, 8*cm*gd])  fusion = tf.nn.relu(fusion_grid+fusion_global)

bilateral grid of coefficients

with tf.variable_scope('prediction'):  current_layer = fusion  current_layer = conv(current_layer, gd*cls.n_out()*cls.n_in(), 1,                              activation_fn=None, scope='conv1')  with tf.name_scope('unroll_grid'):    current_layer = tf.stack(        tf.split(current_layer, cls.n_out()*cls.n_in(), axis=3), axis=4)    current_layer = tf.stack(        tf.split(current_layer, cls.n_in(), axis=4), axis=5)  tf.add_to_collection('packed_coefficients', current_layer)

guidance map g

输入为full-res input I.

def _guide(cls, input_tensor, params, is_training):  npts = 16  # number of control points for the curve  nchans = input_tensor.get_shape().as_list()[-1]  guidemap = input_tensor  # Color space change  idtity = np.identity(nchans, dtype=np.float32) + np.random.randn(1).astype(np.float32)*1e-4  ccm = tf.get_variable('ccm', dtype=tf.float32, initializer=idtity)  with tf.name_scope('ccm'):    ccm_bias = tf.get_variable('ccm_bias', shape=[nchans,], dtype=tf.float32, initializer=tf.constant_initializer(0.0))    guidemap = tf.matmul(tf.reshape(input_tensor, [-1, nchans]), ccm)    guidemap = tf.nn.bias_add(guidemap, ccm_bias, name='ccm_bias_add')    guidemap = tf.reshape(guidemap, tf.shape(input_tensor))  # Per-channel curve  with tf.name_scope('curve'):    shifts_ = np.linspace(0, 1, npts, endpoint=False, dtype=np.float32)    shifts_ = shifts_[np.newaxis, np.newaxis, np.newaxis, :]    shifts_ = np.tile(shifts_, (1, 1, nchans, 1))    guidemap = tf.expand_dims(guidemap, 4)    shifts = tf.get_variable('shifts', dtype=tf.float32, initializer=shifts_)    slopes_ = np.zeros([1, 1, 1, nchans, npts], dtype=np.float32)    slopes_[:, :, :, :, 0] = 1.0    slopes = tf.get_variable('slopes', dtype=tf.float32, initializer=slopes_)    guidemap = tf.reduce_sum(slopes*tf.nn.relu(guidemap-shifts), reduction_indices=[4])  guidemap = tf.contrib.layers.convolution2d(      inputs=guidemap,      num_outputs=1, kernel_size=1,       weights_initializer=tf.constant_initializer(1.0/nchans),      biases_initializer=tf.constant_initializer(0),      activation_fn=None,       variables_collections={'weights':[tf.GraphKeys.WEIGHTS], 'biases':[tf.GraphKeys.BIASES]},      outputs_collections=[tf.GraphKeys.ACTIVATIONS],      scope='channel_mixing')  guidemap = tf.clip_by_value(guidemap, 0, 1)  guidemap = tf.squeeze(guidemap, squeeze_dims=[3,])  return guidemap

sliced coefficients 与 full-res output

def _output(cls, im, guide, coeffs):  with tf.device('/gpu:0'):    out = bilateral_slice_apply(coeffs, guide, im, has_offset=True, name='slice')  return out
def bilateral_slice_apply(grid, guide, input_image, has_offset=True, name=None):  """Slices into a bilateral grid using the guide map.  Args:    grid: (Tensor) [batch_size, grid_h, grid_w, depth, n_outputs]      grid to slice from.    guide: (Tensor) [batch_size, h, w ] guide map to slice along.    input_image: (Tensor) [batch_size, h, w, n_input] input data onto which to      apply the affine transform.    name: (string) name for the operation.  Returns:    sliced: (Tensor) [batch_size, h, w, n_outputs] sliced output.  """  with tf.name_scope(name):    gridshape = grid.get_shape().as_list()    if len(gridshape) == 6:      gs = tf.shape(grid)      _, _, _, _, n_out, n_in = gridshape      grid = tf.reshape(grid, tf.stack([gs[0], gs[1], gs[2], gs[3], gs[4]*gs[5]]))      # grid = tf.concat(tf.unstack(grid, None, axis=5), 4)    sliced = hdrnet_ops.bilateral_slice_apply(grid, guide, input_image, has_offset=has_offset)    return sliced
bilateral_slice_apply = _hdrnet.bilateral_slice_apply

github代码:https://github.com/mgharbi/hdrnet

下载:

git clone https://github.com/mgharbi/hdrnet

安装依赖库:

cd hdrnet

sudo pip2 install -r requirements.txt

编译:

cd hdrnetmake

测试:

cd hdrnetpy.test test

这里写图片描述

返回train.py所在目录,训练:

cd ..python train.py checkpoint/ sample_data/identity/filelist.txt 

checkpoint为模型保存目录,sample_data/identity/filelist.txt 为训练数据路径.

测试训练好的模型:

python run.py checkpoint/ input_val/   test_output/
原创粉丝点击