Speech Enhancement Generative Adversarial Network (code analysis)
This post follows the previous one and focuses on the structure of the network code (main_test).
1. Define the flags, which come in integer, boolean, float, and string variants
flags.DEFINE_integer
flags.DEFINE_boolean
flags.DEFINE_float
flags.DEFINE_string
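These calls come from TensorFlow 1.x's tf.app.flags module. For reference, a minimal sketch of what such definitions look like (the flag names and defaults below are only illustrative; the real list lives in the repository's main script):

import tensorflow as tf

flags = tf.app.flags
# illustrative flag definitions (names/defaults are examples, not the repo's exact values)
flags.DEFINE_integer("epoch", 86, "Number of training epochs")
flags.DEFINE_boolean("bias_deconv", False, "Use biases in the decoder deconvolutions")
flags.DEFINE_float("preemph", 0.95, "Pre-emphasis coefficient (0 disables it)")
flags.DEFINE_string("save_path", "segan_results", "Directory for checkpoints and outputs")
FLAGS = flags.FLAGS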
2. Select CPU/GPU
for device in devices:
    if len(devices) > 1 and 'cpu' in device.name:
        # Use cpu only when we dont have gpus
        continue
    print('Using device: ', device.name)
    udevices.append(device.name)
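The devices list that this loop iterates over is not shown in the post; presumably it comes from TensorFlow's local device enumeration, roughly:

from tensorflow.python.client import device_lib

# each DeviceAttributes entry has a .name such as '/cpu:0' or '/gpu:0'
devices = device_lib.list_local_devices()
udevices = []  # names of the devices the model will actually run on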
3. Select the model (GAN/AE) and whether to train or test
with tf.Session(config=config) as sess:
    if FLAGS.model == 'gan':
        print('Creating GAN model')
        se_model = SEGAN(sess, FLAGS, udevices)
    elif FLAGS.model == 'ae':
        print('Creating AE model')
        se_model = SEAE(sess, FLAGS, udevices)
    else:
        raise ValueError('{} model type not understood!'.format(FLAGS.model))
    if FLAGS.test_wav is None:
        se_model.train(FLAGS, udevices)
    else:
        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        print('Loading model weights...')
4. Clean each single-channel noisy audio file
se_model.load(FLAGS.save_path, FLAGS.weights)
wav_all = os.listdir(FLAGS.test_wav)
wav_number = len(wav_all)
for wav_index in range(wav_number):
    wav_single_full = FLAGS.test_wav + wav_all[wav_index]
    fm, wav_data = wavfile.read(wav_single_full)
    wavname = wav_all[wav_index]
    if fm != 16000:
        raise ValueError('16kHz required! Test file is different')
    wave1 = (2. / 65535.) * (wav_data.astype(np.float32) - 32767) + 1.
    if FLAGS.preemph > 0:
        print('preemph test wave with {}'.format(FLAGS.preemph))
        x_pholder_1, preemph_op_1 = pre_emph_test(FLAGS.preemph, wave1.shape[0])
        wave1 = sess.run(preemph_op_1, feed_dict={x_pholder_1: wave1})
    print('test wave1 shape: ', wave1.shape)
    print('test wave1 min:{} max:{}'.format(np.min(wave1), np.max(wave1)))
    c_wave1 = se_model.clean(wave1)
    wavfile.write(os.path.join(FLAGS.save_clean_path, wavname), 16000, c_wave1)
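Two details worth noting here: the line wave1 = (2./65535.) * (wav_data.astype(np.float32) - 32767) + 1. maps the int16 samples into [-1, 1] (32767 maps to 1.0 and -32768 to -1.0), and pre_emph_test builds a graph op for the first-order pre-emphasis filter y[n] = x[n] - coeff * x[n-1]. A minimal NumPy sketch of the same filter (the repository implements it as a TensorFlow op instead):

import numpy as np

def pre_emph_numpy(x, coeff=0.95):
    # first-order high-pass: y[0] = x[0], y[n] = x[n] - coeff * x[n-1]
    if coeff <= 0:
        return x
    return np.concatenate(([x[0]], x[1:] - coeff * x[:-1])).astype(np.float32)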
5. Inside the clean function
* The waveform is split into chunks, and each chunk is fed through Gs
* The selected model (AE/GAN) then processes each chunk
def clean(self, x):
    """ clean a utterance x
        x: numpy array containing the normalized noisy waveform
    """
    c_res = None
    for beg_i in range(0, x.shape[0], self.canvas_size):
        if x.shape[0] - beg_i < self.canvas_size:
            length = x.shape[0] - beg_i
            pad = (self.canvas_size) - length
        else:
            length = self.canvas_size
            pad = 0
        x_ = np.zeros((self.batch_size, self.canvas_size))
        if pad > 0:
            x_[0] = np.concatenate((x[beg_i:beg_i + length], np.zeros(pad)))
        else:
            x_[0] = x[beg_i:beg_i + length]
        print('Cleaning chunk {} -> {}'.format(beg_i, beg_i + length))
        fdict = {self.gtruth_noisy[0]: x_}
        canvas_w = self.sess.run(self.Gs[0], feed_dict=fdict)[0]
        canvas_w = canvas_w.reshape((self.canvas_size))
        print('canvas w shape: ', canvas_w.shape)
        if pad > 0:
            print('Removing padding of {} samples'.format(pad))
            # get rid of last padded samples
            canvas_w = canvas_w[:-pad]
        if c_res is None:
            c_res = canvas_w
        else:
            c_res = np.concatenate((c_res, canvas_w))
    # deemphasize
    c_res = de_emph(c_res, self.preemph)
    return c_res
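de_emph, called at the end, undoes the pre-emphasis before the cleaned wave is written out. It is not listed in this post; a minimal sketch, assuming the standard first-order inverse of the filter above:

import numpy as np

def de_emph_sketch(y, coeff=0.95):
    # inverse of pre-emphasis: x[n] = y[n] + coeff * x[n-1]
    if coeff <= 0:
        return y
    x = np.zeros_like(y, dtype=np.float32)
    x[0] = y[0]
    for n in range(1, y.shape[0]):
        x[n] = y[n] + coeff * x[n - 1]
    return x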
6. AEGenerator
6.1 make_z
def make_z(shape, mean=0., std=1., name='z'):
    if is_ref:
        with tf.variable_scope(name) as scope:
            z_init = tf.random_normal_initializer(mean=mean, stddev=std)
            z = tf.get_variable("z", shape,
                                initializer=z_init,
                                trainable=False)
            if z.device != "/device:GPU:0":
                # this has to be created into gpu0
                print('z.device is {}'.format(z.device))
                assert False
    else:
        z = tf.random_normal(shape, mean=mean, stddev=std,
                             name=name, dtype=tf.float32)
    return z
6.2 Encoder
* Inside the for loop, h_i is updated layer by layer and downconv extracts the features
* The intermediate activations are also stored in skips, to be fused in the decoder later
* If the z switch (z_on) is enabled, z is concatenated with h_i to form the decoder input
for layer_idx, layer_depth in enumerate(segan.g_enc_depths):
    bias_init = None
    if segan.bias_downconv:
        if is_ref:
            print('Biasing downconv in G')
        bias_init = tf.constant_initializer(0.)
    h_i_dwn = downconv(h_i, layer_depth, kwidth=kwidth,
                       init=tf.truncated_normal_initializer(stddev=0.02),
                       bias_init=bias_init,
                       name='enc_{}'.format(layer_idx))
    if is_ref:
        print('Downconv {} -> {}'.format(h_i.get_shape(), h_i_dwn.get_shape()))
    h_i = h_i_dwn
    if layer_idx < len(segan.g_enc_depths) - 1:
        if is_ref:
            print('Adding skip connection downconv {}'.format(layer_idx))
        # store skip connection
        # last one is not stored cause it's the code
        skips.append(h_i)
    if do_prelu:
        if is_ref:
            print('-- Enc: prelu activation --')
        h_i = prelu(h_i, ref=is_ref, name='enc_prelu_{}'.format(layer_idx))
        if is_ref:
            # split h_i into its components
            alpha_i = h_i[1]
            h_i = h_i[0]
            alphas.append(alpha_i)
    else:
        if is_ref:
            print('-- Enc: leakyrelu activation --')
        h_i = leakyrelu(h_i)

if z_on:
    # random code is fused with intermediate representation
    z = make_z([segan.batch_size, h_i.get_shape().as_list()[1],
                segan.g_enc_depths[-1]])
    h_i = tf.concat([z, h_i], 2)
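downconv, the feature extractor used above, lives in ops.py and is essentially a strided 1-D convolution that halves the time dimension at each layer. A minimal sketch of the idea (the waveform is viewed as a height-1 image so conv2d can be reused; exact argument names in the repository may differ):

def downconv_sketch(x, output_dim, kwidth=5, pool=2,
                    init=None, bias_init=None, name='downconv'):
    """Strided 1-D conv: (batch, time, channels) -> (batch, time // pool, output_dim)."""
    with tf.variable_scope(name):
        x2d = tf.expand_dims(x, 2)  # (batch, time, 1, channels)
        w = tf.get_variable('w', [kwidth, 1, x.get_shape().as_list()[-1], output_dim],
                            initializer=init)
        conv = tf.nn.conv2d(x2d, w, strides=[1, pool, 1, 1], padding='SAME')
        if bias_init is not None:
            b = tf.get_variable('b', [output_dim], initializer=bias_init)
            conv = tf.nn.bias_add(conv, b)
        return tf.squeeze(conv, 2)  # back to (batch, time // pool, output_dim)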
6.3 Decoder
* Choose deconv or nn_deconv for upsampling
* Apply the prelu (or leakyrelu) activation
* Read the stored skips and fuse them with concat
g_dec_depths = segan.g_enc_depths[:-1][::-1] + [1]
if is_ref:
    print('g_dec_depths: ', g_dec_depths)
for layer_idx, layer_depth in enumerate(g_dec_depths):
    h_i_dim = h_i.get_shape().as_list()
    out_shape = [h_i_dim[0], h_i_dim[1] * 2, layer_depth]
    bias_init = None
    # deconv
    if segan.deconv_type == 'deconv':
        if is_ref:
            print('-- Transposed deconvolution type --')
            if segan.bias_deconv:
                print('Biasing deconv in G')
        if segan.bias_deconv:
            bias_init = tf.constant_initializer(0.)
        h_i_dcv = deconv(h_i, out_shape, kwidth=kwidth, dilation=2,
                         init=tf.truncated_normal_initializer(stddev=0.02),
                         bias_init=bias_init,
                         name='dec_{}'.format(layer_idx))
    elif segan.deconv_type == 'nn_deconv':
        if is_ref:
            print('-- NN interpolated deconvolution type --')
            if segan.bias_deconv:
                print('Biasing deconv in G')
        if segan.bias_deconv:
            bias_init = 0.
        h_i_dcv = nn_deconv(h_i, kwidth=kwidth, dilation=2,
                            init=tf.truncated_normal_initializer(stddev=0.02),
                            bias_init=bias_init,
                            name='dec_{}'.format(layer_idx))
    else:
        raise ValueError('Unknown deconv type {}'.format(segan.deconv_type))
    if is_ref:
        print('Deconv {} -> {}'.format(h_i.get_shape(), h_i_dcv.get_shape()))
    h_i = h_i_dcv
    if layer_idx < len(g_dec_depths) - 1:
        if do_prelu:
            if is_ref:
                print('-- Dec: prelu activation --')
            h_i = prelu(h_i, ref=is_ref, name='dec_prelu_{}'.format(layer_idx))
            if is_ref:
                # split h_i into its components
                alpha_i = h_i[1]
                h_i = h_i[0]
                alphas.append(alpha_i)
        else:
            if is_ref:
                print('-- Dec: leakyrelu activation --')
            h_i = leakyrelu(h_i)
        # fuse skip connection
        skip_ = skips[-(layer_idx + 1)]
        if is_ref:
            print('Fusing skip connection of shape {}'.format(skip_.get_shape()))
        h_i = tf.concat([h_i, skip_], 2)
    else:
        if is_ref:
            print('-- Dec: tanh activation --')
        h_i = tf.tanh(h_i)
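Of the two upsampling choices, deconv is a transposed convolution, while nn_deconv first repeats each time step (nearest-neighbor interpolation) and then smooths the result with an ordinary conv1d, a common way to avoid checkerboard artifacts. A minimal sketch of the nn_deconv idea (helper names here are illustrative, not necessarily the repository's exact implementation):

def nn_upsample_sketch(x, rep):
    """Nearest-neighbor upsampling along time: (batch, time, ch) -> (batch, time * rep, ch)."""
    shape = x.get_shape().as_list()
    x_rep = tf.tile(tf.expand_dims(x, 2), [1, 1, rep, 1])  # duplicate every time step
    return tf.reshape(x_rep, [-1, shape[1] * rep, shape[2]])

def nn_deconv_sketch(x, kwidth=5, dilation=2, num_kernels=1,
                     init=None, name='nn_deconv'):
    with tf.variable_scope(name):
        interp = nn_upsample_sketch(x, dilation)  # upsample the time axis by the dilation factor
        # smooth the repeated samples with a learned 1-D convolution
        return conv1d(interp, kwidth=kwidth, num_kernels=num_kernels, init=init)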
7. The Generator model
* First, make_z is called
* residual_block (detailed in Section 8) returns res_i
* skips → stack → leakyrelu → conv1d → tanh → wave
def make_z(shape, mean=0., std=1., name='z'):
    if is_ref:
        with tf.variable_scope(name) as scope:
            z_init = tf.random_normal_initializer(mean=mean, stddev=std)
            z = tf.get_variable("z", shape,
                                initializer=z_init,
                                trainable=False)
            if z.device != "/device:GPU:0":
                # this has to be created into gpu0
                print('z.device is {}'.format(z.device))
                assert False
    else:
        z = tf.random_normal(shape, mean=mean, stddev=std,
                             name=name, dtype=tf.float32)
    return z

if hasattr(segan, 'generator_built'):
    tf.get_variable_scope().reuse_variables()
    make_vars = False
else:
    make_vars = True

print('*** Building Generator ***')
in_dims = noisy_w.get_shape().as_list()
h_i = noisy_w
if len(in_dims) == 2:
    h_i = tf.expand_dims(noisy_w, -1)
elif len(in_dims) < 2 or len(in_dims) > 3:
    raise ValueError('Generator input must be 2-D or 3-D')
kwidth = 3
z = make_z([segan.batch_size, h_i.get_shape().as_list()[1],
            segan.g_enc_depths[-1]])
h_i = tf.concat([h_i, z], 2)
skip_out = True
skips = []
for block_idx, dilation in enumerate(segan.g_dilated_blocks):
    name = 'g_residual_block_{}'.format(block_idx)
    if block_idx >= len(segan.g_dilated_blocks) - 1:
        skip_out = False
    if skip_out:
        res_i, skip_i = residual_block(h_i, dilation, kwidth, num_kernels=32,
                                       bias_init=None, stddev=0.02,
                                       do_skip=True, name=name)
    else:
        res_i = residual_block(h_i, dilation, kwidth, num_kernels=32,
                               bias_init=None, stddev=0.02,
                               do_skip=False, name=name)
    # feed the residual output to the next block
    h_i = res_i
    if segan.keep_prob < 1:
        print('Adding dropout w/ keep prob {} to G'.format(segan.keep_prob))
        h_i = tf.nn.dropout(h_i, segan.keep_prob_var)
    if skip_out:
        # accumulate the skip connections
        skips.append(skip_i)
    else:
        # for last block, the residual output is appended
        skips.append(res_i)
print('Amount of skip connections: ', len(skips))
# TODO: last pooling for actual wave
with tf.variable_scope('g_wave_pooling'):
    skip_T = tf.stack(skips, axis=0)
    skips_sum = tf.reduce_sum(skip_T, axis=0)
    skips_sum = leakyrelu(skips_sum)
    wave_a = conv1d(skips_sum, kwidth=1, num_kernels=1,
                    init=tf.truncated_normal_initializer(stddev=0.02))
    wave = tf.tanh(wave_a)
    segan.gen_wave_summ = histogram_summary('gen_wave', wave)
print('Last residual wave shape: ', res_i.get_shape())
print('*************************')
segan.generator_built = True
return wave, z
8. residual_block (ops.py)
* input → atrous_conv1d → (h_a) tanh → h
* input → atrous_conv1d → (z_a) sigmoid → z
* gated_h = z * h
* gated_h → conv1d → h_
* res = h_ + input
* If do_skip is set: gated_h → conv1d → skip
* Returns res, skip (or just res)
def residual_block(input_, dilation, kwidth, num_kernels=1,
                   bias_init=None, stddev=0.02,
                   do_skip=True,
                   name='residual_block'):
    print('input shape to residual block: ', input_.get_shape())
    with tf.variable_scope(name):
        h_a = atrous_conv1d(input_, dilation, kwidth, num_kernels,
                            bias_init=bias_init, stddev=stddev)
        h = tf.tanh(h_a)
        # apply gated activation
        z_a = atrous_conv1d(input_, dilation, kwidth, num_kernels,
                            name='conv_gate', bias_init=bias_init,
                            stddev=stddev)
        z = tf.nn.sigmoid(z_a)
        print('gate shape: ', z.get_shape())
        # element-wise apply the gate
        gated_h = tf.multiply(z, h)
        print('gated h shape: ', gated_h.get_shape())
        # make res connection
        h_ = conv1d(gated_h, kwidth=1, num_kernels=1,
                    init=tf.truncated_normal_initializer(stddev=stddev),
                    name='residual_conv1')
        res = h_ + input_
        print('residual result: ', res.get_shape())
        if do_skip:
            # make skip connection
            skip = conv1d(gated_h, kwidth=1, num_kernels=1,
                          init=tf.truncated_normal_initializer(stddev=stddev),
                          name='skip_conv1')
            return res, skip
        else:
            return res
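atrous_conv1d is the dilated 1-D convolution that gives each residual block its growing receptive field. It is defined in ops.py and not listed in this post; a minimal sketch of the idea using tf.nn.convolution with a dilation_rate (argument names here are illustrative):

def atrous_conv1d_sketch(x, dilation, kwidth=3, num_kernels=1,
                         stddev=0.02, bias_init=None, name='atrous_conv1d'):
    """Dilated 1-D convolution: (batch, time, channels) -> (batch, time, num_kernels)."""
    in_channels = x.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        w = tf.get_variable('w', [kwidth, in_channels, num_kernels],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        # 'SAME' padding keeps the time length; dilation_rate spaces the filter taps apart
        conv = tf.nn.convolution(x, w, padding='SAME', dilation_rate=[dilation])
        if bias_init is not None:
            b = tf.get_variable('b', [num_kernels], initializer=bias_init)
            conv = tf.nn.bias_add(conv, b)
        return conv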