tensorflow1.1/variational_autoencoder
来源:互联网 发布:手机淘宝体检中心登录 编辑:程序博客网 时间:2024/06/05 19:25
环境tensorflow1.1,matplotlib2.02,python3
近年,非监督学习成为了研究热点。VAE(Variational Auto-Encoder,变分自编码器)和 GAN(Generative Adversarial Networks) 等模型,受到越来越多的关注
VAE:模型结构:
其中:loss = mse+KLDivergence
#coding:utf-8"""tensorflow 1.1matplotlib 2.02"""import tensorflow as tfimport numpy as npimport input_dataimport matplotlib.pyplot as pltinput_dim = 784hidden_encoder_dim = 1200hidden_decoder_dim = 1200latent_dim = 200epochs = 3000batch_size = 100N_pictures=3mnist = input_data.read_data_sets('mnist/')def weight_variable(shape): #tf.truncated_normal()截断的标准正态分布 return tf.Variable(tf.truncated_normal(shape,stddev=0.001))def bias_variable(shape): return tf.Variable(tf.truncated_normal(shape))x = tf.placeholder('float32',[None,input_dim])#在全连接层加入l2_regularizationl2_loss = tf.constant(0.0)#encoder网络w_encoder1 =weight_variable([input_dim,hidden_encoder_dim])b_encoder1 = bias_variable([hidden_encoder_dim])encoder1 = tf.nn.relu(tf.matmul(x,w_encoder1)+b_encoder1)#第一层的l2_lossl2_loss += tf.nn.l2_loss(w_encoder1)#定义一个mu网络mu_w_encoder2 = weight_variable([hidden_encoder_dim,latent_dim])mu_b_encoder2 = bias_variable([latent_dim])mu_encoder2 = tf.matmul(encoder1,mu_w_encoder2)+mu_b_encoder2#mu网络的l2_lossl2_loss += tf.nn.l2_loss(mu_w_encoder2)#定义一个var网络var_w_encoder2 = weight_variable([hidden_encoder_dim,latent_dim])var_b_encoder2 = bias_variable([latent_dim])var_encoder2 = tf.matmul(encoder1,var_w_encoder2)+var_b_encoder2#var网络的l2_lossl2_loss += tf.nn.l2_loss(var_w_encoder2)#抽样#生成标准正态分布epsilon = tf.random_normal(tf.shape(var_encoder2))new_var_encoder2 = tf.sqrt(tf.exp(var_encoder2))#z的维度是latent_dimz = mu_encoder2+tf.multiply(new_var_encoder2,epsilon)#定义decoder网络w_decoder1 = weight_variable([latent_dim,hidden_decoder_dim])b_decoder1 = bias_variable([hidden_decoder_dim])decoder1 = tf.nn.relu(tf.matmul(z,w_decoder1)+b_decoder1)l2_loss += tf.nn.l2_loss(w_decoder1)w_decoder2 = weight_variable([hidden_decoder_dim,input_dim])b_decoder2 = bias_variable([input_dim])#输出层没有使用激活函数(加入激活函数后面用log_px_given_z,不加入激活函数用cost1)decoder2 = tf.nn.sigmoid(tf.matmul(decoder1,w_decoder2)+b_decoder2)l2_loss += tf.nn.l2_loss(w_decoder2)#计算costlog_px_given_z = -tf.reduce_sum(x*tf.log(decoder2+1e-10)+(1-x)*tf.log(1-decoder2+1e-10),1)#cost1 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=decoder2,labels=x),reduction_indices=1)#计算KL DivergenceKLD = -0.5*tf.reduce_sum(1+var_encoder2-tf.pow(mu_encoder2,2)-tf.exp(var_encoder2),reduction_indices=1)cost = tf.reduce_mean(log_px_given_z+KLD)#加上regularization regularized_cost = cost + l2_losstrain = tf.train.AdamOptimizer(0.01).minimize(cost)with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) #画图,2行5列返回图和子图 figure_,a = plt.subplots(2,N_pictures,figsize=(6,4)) #开始交互模式 plt.ion() #测试的图 view_figures = mnist.test.images[:N_pictures] for i in range(N_pictures): #将图片reshape为28行28列显示 a[0][i].imshow(np.reshape(view_figures[i],(28,28))) #清空x轴,y轴坐标 a[0][i].set_xticks(()) a[0][i].set_yticks(()) for step in range(10000): batch_x,batch_y = mnist.train.next_batch(batch_size) #encoder3和decoder3需要进行run _,encoded,decoded,c = sess.run([train,z,decoder2,cost],feed_dict={x:batch_x}) if step % 1000 ==0: print('= = = = = = > > > > > >','train loss:% .4f' % c) #将真实的图片和autoencoder后的图片对比 decoder_figures = sess.run(decoder2,feed_dict={x:view_figures}) for i in range(N_pictures): #清除第一行图片 a[1][i].clear() a[1][i].imshow(np.reshape(decoder_figures[i],(28,28))) a[1][i].set_xticks(()) a[1][i].set_yticks(()) plt.draw() plt.pause(1) plt.ioff() #关闭交互模式"""with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) for epoch in range(epochs): batch_x,batch_y = mnist.train.next_batch(batch_size) _,c = sess.run([train,cost],feed_dict={x:batch_x}) if epoch % 100 == 0: print('- - - - - - > > > > > > epoch: ',int(epoch/100),'cost: %.4f' %c) #输出结果可视化 encoder_result = sess.run(z,feed_dict={x:mnist.test.images}) plt.scatter(encoder_result[:,0],encoder_result[:,1],c = mnist.test.labels,label='mnist distributions') plt.legend(loc='best') plt.title('different mnist digits shows in figure') plt.colorbar() plt.show()"""
结果
聚类效果:
阅读全文
0 0
- tensorflow1.1/variational_autoencoder
- python3/tensorflow1.1
- tensorflow1.1/线性回归
- tensorflow1.1/optimizer可视化
- tensorflow1.1/tensorboard可视化
- tensorflow1.1/autoencoder2
- tensorflow1.1/embedding可视化
- tensorflow1.1/RNN预测
- win7 + tensorflow1.1+k40c , win10+tensorflow1.4+1080ti
- tensorflow1.1/激活函数可视化
- tensorflow1.1/构建神经网络分类
- tensorflow1.1及python3安装
- 1、Tensorflow:Windows10+tensorflow1.4
- tensorflow1.1 和python3.5.3安装
- 从TensorFlow0.12升级到TensorFlow1.1
- tensorflow1.1的几个api变动
- TensorFlow1.1搭建自编码网络
- tensorflow1.1/非监督学习autoencoder1
- select 语句查询
- 安卓——WIFI连接
- Android 把bitmap图片的某一部分的颜色改成其他颜色
- 算法总览
- 9列表
- tensorflow1.1/variational_autoencoder
- 11.2 rac psu 补丁
- 使用source命令向MySQL导入超大文件
- Ubuntu下删除两个文件夹下相同文件名且相同内容的文件(分色排版)
- sass的具体使用方法
- 11表格元素下
- 在前端的一些注意的问题
- echarts入门教程(含小案例)
- Texas Trip