Speech Enhancement Generative Adversarial Network (code analysis)

来源:互联网 发布:linux中开机自动启动 编辑:程序博客网 时间:2024/06/10 15:33

本文接上一篇文章,主要讲述网络代码结构(main_test)


1、定义flags,分别为整型、布尔型、浮点型、字符串型

flags.DEFINE_integer
flags.DEFINE_boolean
flags.DEFINE_float
flags.DEFINE_string

2、选择CPU/GPU

# Collect the names of the devices the model will be built on.
for device in devices:
    # Use the CPU only when it is the sole device available.
    # The match is case-insensitive because device names are spelled
    # '/cpu:0' by older TensorFlow releases and '/device:CPU:0' by newer
    # ones; a lowercase-only test would never skip the latter.
    if len(devices) > 1 and 'cpu' in device.name.lower():
        continue
    print('Using device: ', device.name)
    udevices.append(device.name)

3、选择模型(GAN/AE)以及训练/测试

with tf.Session(config=config) as sess:
    # Build the model requested through FLAGS.model.
    model_type = FLAGS.model
    if model_type == 'gan':
        print('Creating GAN model')
        se_model = SEGAN(sess, FLAGS, udevices)
    elif model_type == 'ae':
        print('Creating AE model')
        se_model = SEAE(sess, FLAGS, udevices)
    else:
        raise ValueError('{} model type not understood!'.format(model_type))
    # No test wav given -> train; otherwise pre-trained weights are required.
    training = FLAGS.test_wav is None
    if training:
        se_model.train(FLAGS, udevices)
    else:
        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        print('Loading model weights...')

4、对单通道带噪音频进行clean处理

# Restore the trained weights, then enhance every file found in the
# FLAGS.test_wav directory and write the result to FLAGS.save_clean_path.
se_model.load(FLAGS.save_path, FLAGS.weights)
wav_all = os.listdir(FLAGS.test_wav)
for wavname in wav_all:
    # os.path.join works whether or not FLAGS.test_wav ends with a path
    # separator (plain string concatenation required a trailing one).
    wav_single_full = os.path.join(FLAGS.test_wav, wavname)
    fm, wav_data = wavfile.read(wav_single_full)
    if fm != 16000:
        raise ValueError('16kHz required! Test file is different')
    # Rescale the samples with the same affine map the original code uses.
    wave1 = (2./65535.) * (wav_data.astype(np.float32) - 32767) + 1.
    if FLAGS.preemph > 0:
        print('preemph test wave with {}'.format(FLAGS.preemph))
        # NOTE(review): pre_emph_test is called once per file, so it
        # presumably adds new ops to the graph each iteration - confirm.
        x_pholder_1, preemph_op_1 = pre_emph_test(FLAGS.preemph,
                                                  wave1.shape[0])
        wave1 = sess.run(preemph_op_1, feed_dict={x_pholder_1: wave1})
    print('test wave1 shape: ', wave1.shape)
    print('test wave1 min:{}  max:{}'.format(np.min(wave1), np.max(wave1)))
    c_wave1 = se_model.clean(wave1)
    wavfile.write(os.path.join(FLAGS.save_clean_path, wavname), 16000,
                  c_wave1)

5、进入clean函数
* 对数据进行分段处理,进入Gs
* 进而选择模型(AE/GAN)

def clean(self, x):
    """Clean a single utterance chunk by chunk.

    The waveform is walked in steps of ``self.canvas_size`` samples. Each
    chunk is written into row 0 of an all-zero batch of shape
    ``(self.batch_size, self.canvas_size)`` (so a short final chunk is
    zero-padded), fed to ``self.gtruth_noisy[0]``, and row 0 of the
    ``self.Gs[0]`` output is kept. Padding samples are dropped again before
    the chunks are joined and passed through ``de_emph``.

    Args:
        x: numpy array containing the normalized noisy waveform.

    Returns:
        The value returned by ``de_emph`` for the concatenated chunks and
        ``self.preemph``.
    """
    canvas_size = self.canvas_size
    n_samples = x.shape[0]
    cleaned_chunks = []
    for beg_i in range(0, n_samples, canvas_size):
        # The last chunk may be shorter than the canvas.
        length = min(canvas_size, n_samples - beg_i)
        pad = canvas_size - length
        x_ = np.zeros((self.batch_size, canvas_size))
        # The tail of row 0 stays zero, which is the padding.
        x_[0, :length] = x[beg_i:beg_i + length]
        print('Cleaning chunk {} -> {}'.format(beg_i, beg_i + length))
        fdict = {self.gtruth_noisy[0]: x_}
        canvas_w = self.sess.run(self.Gs[0],
                                 feed_dict=fdict)[0]
        canvas_w = canvas_w.reshape((canvas_size))
        print('canvas w shape: ', canvas_w.shape)
        if pad > 0:
            print('Removing padding of {} samples'.format(pad))
            # get rid of last padded samples
            canvas_w = canvas_w[:-pad]
        cleaned_chunks.append(canvas_w)
    # Join once at the end instead of re-copying the result per chunk.
    # An empty input yields None, exactly as the accumulator started out.
    c_res = np.concatenate(cleaned_chunks) if cleaned_chunks else None
    # deemphasize
    c_res = de_emph(c_res, self.preemph)
    return c_res

6、AEGenerator
6.1 make z

 def make_z(shape, mean=0., std=1., name='z'):
     """Create the latent noise tensor ``z``.

     When the enclosing ``is_ref`` flag is true, ``z`` is a non-trainable
     variable called "z" inside variable scope ``name``, initialised from a
     normal distribution; it must live on "/device:GPU:0". Otherwise ``z``
     is a ``tf.random_normal`` tensor named ``name``.

     Args:
         shape: shape of the tensor to create.
         mean: mean of the normal distribution.
         std: standard deviation of the normal distribution.
         name: variable scope name (reference case) or op name.

     Returns:
         The created variable or tensor.

     Raises:
         AssertionError: if the reference variable was not placed on
             "/device:GPU:0".
     """
     if is_ref:
         with tf.variable_scope(name):
             z_init = tf.random_normal_initializer(mean=mean, stddev=std)
             z = tf.get_variable("z", shape,
                                 initializer=z_init,
                                 trainable=False
                                 )
             if z.device != "/device:GPU:0":
                 # this has to be created into gpu0
                 print('z.device is {}'.format(z.device))
                 # Raised explicitly: a bare `assert False` would be
                 # removed when Python runs with -O.
                 raise AssertionError(
                     'z must be created on /device:GPU:0, '
                     'got {}'.format(z.device))
     else:
         z = tf.random_normal(shape, mean=mean, stddev=std,
                              name=name, dtype=tf.float32)
     return z

6.2 First encoder
* 在for循环内循环读入h_i,采用downconv函数提取特征
* 此过程还存入了skips,为输入second coder做准备
* 如果z开关打开,即拼接z和h_i,作为second coder的输入

# Encoder: one downconv + activation per entry of segan.g_enc_depths.
for layer_idx, layer_depth in enumerate(segan.g_enc_depths):
    # A bias initializer is only passed when segan.bias_downconv is set.
    bias_init = None
    if segan.bias_downconv:
        if is_ref:
            print('Biasing downconv in G')
        bias_init = tf.constant_initializer(0.)
    h_i_dwn = downconv(h_i, layer_depth, kwidth=kwidth,
                       init=tf.truncated_normal_initializer(stddev=0.02),
                       bias_init=bias_init,
                       name='enc_{}'.format(layer_idx))
    if is_ref:
        print('Downconv {} -> {}'.format(h_i.get_shape(),
                                         h_i_dwn.get_shape()))
    h_i = h_i_dwn
    # The skip is taken from the downconv output, before the activation.
    if layer_idx < len(segan.g_enc_depths) - 1:
        if is_ref:
            print('Adding skip connection downconv '
                  '{}'.format(layer_idx))
        # store skip connection
        # last one is not stored cause it's the code
        skips.append(h_i)
    if do_prelu:
        if is_ref:
            print('-- Enc: prelu activation --')
        h_i = prelu(h_i, ref=is_ref, name='enc_prelu_{}'.format(layer_idx))
        if is_ref:
            # split h_i into its components
            # (with ref=True the result is indexed as [output, alpha])
            alpha_i = h_i[1]
            h_i = h_i[0]
            alphas.append(alpha_i)
    else:
        if is_ref:
            print('-- Enc: leakyrelu activation --')
        h_i = leakyrelu(h_i)
if z_on:
    # random code is fused with intermediate representation
    # z takes batch size, axis-1 length of h_i and the last encoder depth,
    # and is concatenated in front of h_i along axis 2.
    z = make_z([segan.batch_size, h_i.get_shape().as_list()[1],
                segan.g_enc_depths[-1]])
    h_i = tf.concat([z, h_i],2)

6.3 Decoder(即上文所说的second coder)
* 选择deconv或者nn_deconv
* 激活函数prelu
* 读取skip,使用concat拼接

# Decoder depths: the encoder depths without the last one, reversed,
# followed by a final depth of 1.
g_dec_depths = segan.g_enc_depths[:-1][::-1] + [1]
if is_ref:
    print('g_dec_depths: ', g_dec_depths)
for layer_idx, layer_depth in enumerate(g_dec_depths):
    h_i_dim = h_i.get_shape().as_list()
    # Each layer doubles axis 1 and maps axis 2 to layer_depth.
    out_shape = [h_i_dim[0], h_i_dim[1] * 2, layer_depth]
    bias_init = None
    # deconv
    if segan.deconv_type == 'deconv':
        if is_ref:
            print('-- Transposed deconvolution type --')
            if segan.bias_deconv:
                print('Biasing deconv in G')
        if segan.bias_deconv:
            bias_init = tf.constant_initializer(0.)
        h_i_dcv = deconv(h_i, out_shape, kwidth=kwidth, dilation=2,
                         init=tf.truncated_normal_initializer(stddev=0.02),
                         bias_init=bias_init,
                         name='dec_{}'.format(layer_idx))
    elif segan.deconv_type == 'nn_deconv':
        if is_ref:
            print('-- NN interpolated deconvolution type --')
            if segan.bias_deconv:
                print('Biasing deconv in G')
        if segan.bias_deconv:
            # Unlike the 'deconv' branch, a plain float is passed here.
            # NOTE(review): presumably nn_deconv wraps it in an
            # initializer itself - confirm against ops.py.
            bias_init = 0.
        h_i_dcv = nn_deconv(h_i, kwidth=kwidth, dilation=2,
                            init=tf.truncated_normal_initializer(stddev=0.02),
                            bias_init=bias_init,
                            name='dec_{}'.format(layer_idx))
    else:
        raise ValueError('Unknown deconv type {}'.format(segan.deconv_type))
    if is_ref:
        print('Deconv {} -> {}'.format(h_i.get_shape(),
                                       h_i_dcv.get_shape()))
    h_i = h_i_dcv
    if layer_idx < len(g_dec_depths) - 1:
        # Every layer but the last: activation, then skip fusion.
        if do_prelu:
            if is_ref:
                print('-- Dec: prelu activation --')
            h_i = prelu(h_i, ref=is_ref,
                        name='dec_prelu_{}'.format(layer_idx))
            if is_ref:
                # split h_i into its components
                alpha_i = h_i[1]
                h_i = h_i[0]
                alphas.append(alpha_i)
        else:
            if is_ref:
                print('-- Dec: leakyrelu activation --')
            h_i = leakyrelu(h_i)
        # fuse skip connection
        # skips are consumed from the end of the list backwards
        skip_ = skips[-(layer_idx + 1)]
        if is_ref:
            print('Fusing skip connection of '
                  'shape {}'.format(skip_.get_shape()))
        h_i = tf.concat([h_i, skip_],2)
    else:
        # Last layer: tanh output, no skip connection.
        if is_ref:
            print('-- Dec: tanh activation --')
        h_i = tf.tanh(h_i)

7、Generator模型
* 首先make_z
* residual_block详见8 return res_i
* skips→stack→reduce_sum→leakyrelu→conv1d→tanh→wave

 def make_z(shape, mean=0., std=1., name='z'):
     """Create the latent tensor z.

     With is_ref true: a non-trainable variable "z" under variable scope
     `name`, normally initialised, which must sit on "/device:GPU:0".
     Otherwise: a tf.random_normal tensor named `name`.
     """
     if is_ref:
         with tf.variable_scope(name) as scope:
             z_init = tf.random_normal_initializer(mean=mean, stddev=std)
             z = tf.get_variable("z", shape,
                                 initializer=z_init,
                                 trainable=False
                                 )
             if z.device != "/device:GPU:0":
                 # this has to be created into gpu0
                 print('z.device is {}'.format(z.device))
                 assert False
     else:
         z = tf.random_normal(shape, mean=mean, stddev=std,
                              name=name, dtype=tf.float32)
     return z
 # Reuse the existing variables when the generator was already built once.
 # make_vars is assigned here but not read anywhere in this excerpt.
 if hasattr(segan, 'generator_built'):
     tf.get_variable_scope().reuse_variables()
     make_vars = False
 else:
     make_vars = True
 print('*** Building Generator ***')
 in_dims = noisy_w.get_shape().as_list()
 h_i = noisy_w
 # A 2-D input gets a trailing axis; anything but 2-D/3-D is rejected.
 if len(in_dims) == 2:
     h_i = tf.expand_dims(noisy_w, -1)
 elif len(in_dims) < 2 or len(in_dims) > 3:
     raise ValueError('Generator input must be 2-D or 3-D')
 kwidth = 3
 # z is concatenated to the input along axis 2 before the residual blocks.
 z = make_z([segan.batch_size, h_i.get_shape().as_list()[1],
             segan.g_enc_depths[-1]])
 h_i = tf.concat([h_i, z],2)
 skip_out = True
 skips = []
 for block_idx, dilation in enumerate(segan.g_dilated_blocks):
     name = 'g_residual_block_{}'.format(block_idx)
     # The last block produces no separate skip output.
     if block_idx >= len(segan.g_dilated_blocks) - 1:
         skip_out = False
     if skip_out:
         res_i, skip_i = residual_block(h_i,
                                        dilation, kwidth, num_kernels=32,
                                        bias_init=None, stddev=0.02,
                                        do_skip = True,
                                        name=name)
     else:
         res_i = residual_block(h_i,
                                dilation, kwidth, num_kernels=32,
                                bias_init=None, stddev=0.02,
                                do_skip = False,
                                name=name)
     # feed the residual output to the next block
     h_i = res_i
     if segan.keep_prob < 1:
         print('Adding dropout w/ keep prob {} '
               'to G'.format(segan.keep_prob))
         h_i = tf.nn.dropout(h_i, segan.keep_prob_var)
     if skip_out:
         # accumulate the skip connections
         skips.append(skip_i)
     else:
         # for last block, the residual output is appended
         # (res_i is taken before the optional dropout above)
         skips.append(res_i)
 print('Amount of skip connections: ', len(skips))
 # TODO: last pooling for actual wave
 with tf.variable_scope('g_wave_pooling'):
     # Stack the skips on a new leading axis and sum over it, then
     # leakyrelu -> 1-kernel conv1d of width 1 -> tanh gives the wave.
     skip_T = tf.stack(skips, axis=0)
     skips_sum = tf.reduce_sum(skip_T, axis=0)
     skips_sum = leakyrelu(skips_sum)
     wave_a = conv1d(skips_sum, kwidth=1, num_kernels=1,
                     init=tf.truncated_normal_initializer(stddev=0.02))
     wave = tf.tanh(wave_a)
     segan.gen_wave_summ = histogram_summary('gen_wave', wave)
 print('Last residual wave shape: ', res_i.get_shape())
 print('*************************')
 # Mark the generator as built so a later call takes the reuse branch.
 segan.generator_built = True
 return wave, z

8.Residual_block(ops.py)
* input→atrous_conv1d→(h_a)tanh→h
* input→atrous conv1d→(z_a)sigmoid→z
* gated_h=z*h
* gated_h→conv1d→h_
* res=h_+input
* 如果存在skip,则gated_h→conv1d→skip
* return res skip

def residual_block(input_, dilation, kwidth, num_kernels=1,
                   bias_init=None, stddev=0.02, do_skip=True,
                   name='residual_block'):
    """Gated residual block built from two dilated convolutions.

    Inside variable scope ``name`` the input goes through two
    ``atrous_conv1d`` calls: one followed by tanh, one (named 'conv_gate')
    followed by sigmoid. Their element-wise product is projected by a
    width-1, single-kernel ``conv1d`` ('residual_conv1') and added to the
    input. With ``do_skip`` a second width-1, single-kernel ``conv1d``
    ('skip_conv1') of the same product is returned as well.

    Args:
        input_: input tensor.
        dilation: dilation passed to both ``atrous_conv1d`` calls.
        kwidth: kernel width passed to both ``atrous_conv1d`` calls.
        num_kernels: number of kernels of the dilated convolutions.
        bias_init: bias initializer forwarded to ``atrous_conv1d``.
        stddev: standard deviation used for weight initialisation.
        do_skip: whether to also build and return the skip output.
        name: variable scope of the block.

    Returns:
        ``(res, skip)`` when ``do_skip`` is true, otherwise just ``res``.
    """
    print('input shape to residual block: ', input_.get_shape())
    with tf.variable_scope(name):
        # Filter branch.
        filter_pre = atrous_conv1d(input_, dilation, kwidth, num_kernels,
                                   bias_init=bias_init, stddev=stddev)
        filter_act = tf.tanh(filter_pre)
        # Gate branch.
        gate_pre = atrous_conv1d(input_, dilation, kwidth, num_kernels,
                                 name='conv_gate', bias_init=bias_init,
                                 stddev=stddev)
        gate_act = tf.nn.sigmoid(gate_pre)
        print('gate shape: ', gate_act.get_shape())
        # Gated activation: element-wise product of gate and filter.
        gated = tf.multiply(gate_act, filter_act)
        print('gated h shape: ', gated.get_shape())
        # Residual connection.
        projected = conv1d(gated, kwidth=1, num_kernels=1,
                           init=tf.truncated_normal_initializer(stddev=stddev),
                           name='residual_conv1')
        res = projected + input_
        print('residual result: ', res.get_shape())
        if not do_skip:
            return res
        # Skip connection.
        skip = conv1d(gated, kwidth=1, num_kernels=1,
                      init=tf.truncated_normal_initializer(stddev=stddev),
                      name='skip_conv1')
        return res, skip
原创粉丝点击