Triple Generative Adversarial Nets

来源:互联网 发布:wifi网络连接不可用 编辑:程序博客网 时间:2024/05/21 09:44

这里写图片描述

triple-GAN由三部分组成:

(1)分类器C

(2)生成网络G

(3)判别网络D

分类器C和生成网络G的输出都输入判别网络D,目标函数为:

这里写图片描述

上式达到均衡的条件是 $p(x,y)=(1-\alpha)\,p_g(x,y)+\alpha\,p_c(x,y)$，这表明 C、G 趋向于与输入数据同分布。虽然满足该条件，但仍无法保证 $p(x,y)=p_g(x,y)=p_c(x,y)$ 为全局最优解。为了解决这个问题，向分类器 C 引入标准监督损失函数（例如交叉熵损失）：

这里写图片描述

这样目标函数重新定义为:

这里写图片描述

算法流程如下:

这里写图片描述

代码分析:

代码参考:https://github.com/zhenxuan00/triple-gan

分类网络C:

# Classifier x2y: p_c(x, y) = p(x) p_c(y | x).
# A small CNN that maps a flattened 28x28 image to a softmax over classes.
# NOTE(review): this snippet was extracted with its newlines collapsed; the
# statement boundaries below are reconstructed from the `)...append(` seams.
cla_in_x = ll.InputLayer(shape=(None, 28**2))  # flattened MNIST-sized input
cla_layers = [cla_in_x]
# Restore the 2-D image layout expected by the conv layers: (batch, 1, 28, 28).
cla_layers.append(ll.ReshapeLayer(cla_layers[-1], (-1, 1, 28, 28)))
# Five conv stages; `ps=2` stages downsample (pooling), `ps=1` stages do not.
# Dropout (dr=0.5) is applied only on the downsampling stages.
cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0.5, ps=2, n_kerns=32,
                            d_kerns=(5, 5), pad='valid', stride=1, W=Normal(0.05),
                            nonlinearity=ln.rectify, name='cla-1'))
cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=64,
                            d_kerns=(3, 3), pad='same', stride=1, W=Normal(0.05),
                            nonlinearity=ln.rectify, name='cla-2'))
cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0.5, ps=2, n_kerns=64,
                            d_kerns=(3, 3), pad='valid', stride=1, W=Normal(0.05),
                            nonlinearity=ln.rectify, name='cla-3'))
cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=128,
                            d_kerns=(3, 3), pad='same', stride=1, W=Normal(0.05),
                            nonlinearity=ln.rectify, name='cla-4'))
cla_layers.append(convlayer(l=cla_layers[-1], bn=True, dr=0, ps=1, n_kerns=128,
                            d_kerns=(3, 3), pad='same', stride=1, W=Normal(0.05),
                            nonlinearity=ln.rectify, name='cla-5'))
# Global average pooling over spatial dims, then a softmax output layer.
cla_layers.append(ll.GlobalPoolLayer(cla_layers[-1]))
cla_layers.append(ll.DenseLayer(cla_layers[-1], num_units=num_classes,
                                W=lasagne.init.Normal(1e-2, 0),
                                nonlinearity=ln.softmax, name='cla-6'))
classifier = cla_layers[-1]

生成网络G:

# Generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z ~ p_g(z).
# An MLP that is conditioned on the label y by concatenating y before every
# dense layer (via MLPConcatLayer), mapping noise z to a flattened 28x28 image.
# NOTE(review): reconstructed from a newline-collapsed extraction.
gen_in_z = ll.InputLayer(shape=(None, n_z))   # latent noise input
gen_in_y = ll.InputLayer(shape=(None,))       # class-label input
gen_layers = [gen_in_z]
# Block 1: concat label, dense(500, softplus), batch norm.
gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-1'))
gen_layers.append(ll.batch_norm(
    ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-2'),
    name='gen-3'))
# Block 2: same shape as block 1.
gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-4'))
gen_layers.append(ll.batch_norm(
    ll.DenseLayer(gen_layers[-1], num_units=500, nonlinearity=ln.softplus, name='gen-5'),
    name='gen-6'))
# Output: concat label once more, then project to 28*28 pixels with the
# configured final nonlinearity, L2-normalized via the weight-norm wrapper.
gen_layers.append(MLPConcatLayer([gen_layers[-1], gen_in_y], num_classes, name='gen-7'))
gen_layers.append(nn.l2normalize(
    ll.DenseLayer(gen_layers[-1], num_units=28**2, nonlinearity=gen_final_non, name='gen-8')))

判别网络D:

# Discriminator xy2p: tests whether a pair (x, y) comes from the true joint
# p(x, y) rather than from p_c (classifier) or p_g (generator).
# Structure: Gaussian noise on the input and between dense layers for
# regularization, with the label y re-concatenated before every dense layer.
# NOTE(review): reconstructed from a newline-collapsed extraction.
dis_in_x = ll.InputLayer(shape=(None, 28**2))  # flattened image input
dis_in_y = ll.InputLayer(shape=(None,))        # class-label input
dis_layers = [dis_in_x]
dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D_data, name='dis-1'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-2'))
dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=1000, name='dis-3'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-4'))
dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-5'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-6'))
dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=500, name='dis-7'))
dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-8'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-9'))
dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-10'))
dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-11'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-12'))
dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-13'))
dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-14'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-15'))
dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=250, name='dis-16'))
dis_layers.append(nn.GaussianNoiseLayer(dis_layers[-1], sigma=noise_D, name='dis-17'))
dis_layers.append(MLPConcatLayer([dis_layers[-1], dis_in_y], num_classes, name='dis-18'))
# Single sigmoid unit: probability that the pair is from the true joint.
dis_layers.append(nn.DenseLayer(dis_layers[-1], num_units=1, nonlinearity=ln.sigmoid, name='dis-19'))

目标函数:

'''objectives'''
# Symbolic outputs of the three networks and the Triple-GAN cost expressions.
# NOTE(review): this snippet was extracted with newlines collapsed and one
# assignment split across two lines (`weight_decay_classifier = ...`); the
# statement boundaries and if/elif indentation below are reconstructed.

# outputs
gen_out_x = ll.get_output(gen_layers[-1], {gen_in_y: sym_y_g, gen_in_z: sym_z_rand}, deterministic=False)
cla_out_y_l = ll.get_output(cla_layers[-1], sym_x_l, deterministic=False)      # labeled data
cla_out_y_eval = ll.get_output(cla_layers[-1], sym_x_eval, deterministic=True) # evaluation pass
cla_out_y = ll.get_output(cla_layers[-1], sym_x_u, deterministic=False)        # unlabeled data
cla_out_y_d = ll.get_output(cla_layers[-1], {cla_in_x: sym_x_u_d}, deterministic=False)
cla_out_y_d_hard = cla_out_y_d.argmax(axis=1)  # pseudo-labels for D's "real" pairs
cla_out_y_g = ll.get_output(cla_layers[-1], {cla_in_x: gen_out_x}, deterministic=False)
# D's output on (labeled + pseudo-labeled) pairs treated as samples of p(x, y).
dis_out_p = ll.get_output(dis_layers[-1],
                          {dis_in_x: T.concatenate([sym_x_l, sym_x_u_d], axis=0),
                           dis_in_y: T.concatenate([sym_y, cla_out_y_d_hard], axis=0)},
                          deterministic=False)
# D's output on generated pairs (x = G(z, y), y ~ p(y)).
dis_out_p_g = ll.get_output(dis_layers[-1], {dis_in_x: gen_out_x, dis_in_y: sym_y_g}, deterministic=False)
if objective_flag == 'integrate':
    # integrate: evaluate D on every (x, y) class pairing; expectation over
    # p_c(y|x) is taken later by weighting with cla_out_y.
    dis_out_p_c = ll.get_output(dis_layers[-1],
                                {dis_in_x: T.repeat(sym_x_u, num_classes, axis=0),
                                 dis_in_y: np.tile(np.arange(num_classes), batch_size_u_c)},
                                deterministic=False)
elif objective_flag == 'argmax':
    # argmax approximation: use only the classifier's most likely label.
    cla_out_y_hard = cla_out_y.argmax(axis=1)
    dis_out_p_c = ll.get_output(dis_layers[-1], {dis_in_x: sym_x_u, dis_in_y: cla_out_y_hard}, deterministic=False)
else:
    raise Exception('Unknown objective flags')
image = ll.get_output(gen_layers[-1], {gen_in_y: sym_y_g, gen_in_z: sym_z_image}, deterministic=False)  # for generation
accurracy_eval = (lasagne.objectives.categorical_accuracy(cla_out_y_eval, sym_y))  # for evaluation
accurracy_eval = accurracy_eval.mean()

# costs
bce = lasagne.objectives.binary_crossentropy
dis_cost_p = bce(dis_out_p, T.ones(dis_out_p.shape)).mean()      # D distincts p
dis_cost_p_g = bce(dis_out_p_g, T.zeros(dis_out_p_g.shape)).mean()  # D distincts p_g
gen_cost_p_g = bce(dis_out_p_g, T.ones(dis_out_p_g.shape)).mean()   # G fools D
weight_decay_classifier = lasagne.regularization.regularize_layer_params_weighted(
    {cla_layers[-1]: 1}, lasagne.regularization.l2)  # weight decay
dis_cost_p_c = bce(dis_out_p_c, T.zeros(dis_out_p_c.shape))  # D distincts p_c
cla_cost_p_c = bce(dis_out_p_c, T.ones(dis_out_p_c.shape))   # C fools D
if objective_flag == 'integrate':
    # integrate: weight each class's loss by the classifier's probability.
    weight_loss_c = T.reshape(cla_cost_p_c, (-1, num_classes)) * cla_out_y
    cla_cost_p_c = T.sum(weight_loss_c, axis=1).mean()
    weight_loss_d = T.reshape(dis_cost_p_c, (-1, num_classes)) * cla_out_y
    dis_cost_p_c = T.sum(weight_loss_d, axis=1).mean()
elif objective_flag == 'argmax':
    # argmax approximation: scale C's adversarial loss by its confidence.
    p = cla_out_y.max(axis=1)
    cla_cost_p_c = (cla_cost_p_c * p).mean()
    dis_cost_p_c = dis_cost_p_c.mean()
# Standard semi-supervised classification losses (labeled CE + unlabeled
# entropy/average terms + weight decay).
cla_cost_cla = categorical_crossentropy_ssl_separated(
    predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y,
    weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled,
    alpha_unlabeled=sym_alpha_unlabel_entropy, alpha_average=sym_alpha_unlabel_average,
    alpha_decay=alpha_decay)  # classification loss
pretrain_cla_loss = categorical_crossentropy_ssl_separated(
    predictions_l=cla_out_y_l, targets=sym_y, predictions_u=cla_out_y,
    weight_decay=weight_decay_classifier, alpha_labeled=alpha_labeled,
    alpha_unlabeled=pre_alpha_unlabeled_entropy, alpha_average=pre_alpha_average,
    alpha_decay=alpha_decay)  # classification loss (pretraining hyperparameters)
pretrain_cost = pretrain_cla_loss
# C is also trained on G's samples with their conditioning labels as targets.
cla_cost_cla_g = categorical_crossentropy(predictions=cla_out_y_g, targets=sym_y_g)
# Final per-network objectives (the .5 weights implement the alpha mixing).
dis_cost = dis_cost_p + .5 * dis_cost_p_g + .5 * dis_cost_p_c
gen_cost = .5 * gen_cost_p_g
# flag
cla_cost = .5 * cla_cost_p_c + alpha_cla * (cla_cost_cla + sym_alpha_cla_g * cla_cost_cla_g)
原创粉丝点击