Python 实现自编码器(Autoencoder)

来源:互联网 发布:昆山宏观数据库 编辑:程序博客网 时间:2024/06/05 22:35
# -*- coding: utf-8 -*-"""Created on Sun Sep  3 13:48:19 2017@author: piaodexin"""
from __future__ import division, print_function, absolute_importimport tensorflow as tffrom  tensorflow.examples.tutorials.mnist import input_dataimport matplotlib.pyplot as pltimport numpy as npmnist=input_data.read_data_sets('E:\\mnist',one_hot=True)'''定义输入层 (28,28) =784第一层隐含层500个第二层100个第三层500输出层784 这是因为自编码就是希望神经网络自己学习图片特征,然后再用学习到的特征去组成原始图片,所以最后输出层是(28,28)=784'''input_n=784hidden1_n=500hidden2_n=100hidden3_n=500output_n=784learn_rate=0.01batch_size=100train_epoch=30000x=tf.placeholder(tf.float32,[None,input_n])y=tf.placeholder(tf.float32,[None,input_n])weights1=tf.Variable(tf.truncated_normal([input_n,hidden1_n],stddev=0.1))bias1=tf.Variable(tf.constant(0.1,shape=[hidden1_n]))weights2=tf.Variable(tf.truncated_normal([hidden1_n,hidden2_n],stddev=0.1))bias2=tf.Variable(tf.constant(0.1,shape=[hidden2_n]))weights3=tf.Variable(tf.truncated_normal([hidden2_n,hidden3_n],stddev=0.1))bias3=tf.Variable(tf.constant(0.1,shape=[hidden3_n]))weights4=tf.Variable(tf.truncated_normal([hidden3_n,output_n],stddev=0.1))bias4=tf.Variable(tf.constant(0.1,shape=[output_n]))def get_result(x,weights1,bias1,weights2,bias2,weights3,bias3,weights4,bias4):    a1=tf.nn.sigmoid(tf.matmul(x,weights1)+bias1)    a2=tf.nn.sigmoid(tf.matmul(a1,weights2)+bias2)    a3=tf.nn.sigmoid(tf.matmul(a2,weights3)+bias3)    y_=tf.nn.sigmoid(tf.matmul(a3,weights4)+bias4)    return y_'''当我一步一步求y_的时候,却出现错误,只能用函数,不知道为什么'''y_=get_result(x,weights1,bias1,weights2,bias2,weights3,bias3,weights4,bias4)loss=tf.reduce_mean(tf.pow(y_-y,2))train_op=tf.train.RMSPropOptimizer(learn_rate).minimize(loss)with tf.Session() as sess:    tf.global_variables_initializer().run()    for i in range(train_epoch):        xs,ys=mnist.train.next_batch(batch_size)        if i%1000 == 0:            print('epoch:',i)            print('loss:',sess.run(loss,feed_dict={x:xs,y:xs}))        sess.run(train_op,feed_dict={x:xs,y:xs})    xt=mnist.test.images[:5]    yt=xt     encode_decode=sess.run(y_,feed_dict={x:xt,y:yt})    f,a 
=plt.subplots(2,5,figsize=(10,2))    for i in range(5):        a[0][i].imshow(np.reshape(mnist.test.images[i],(28,28)))        a[1][i].imshow(np.reshape(encode_decode[i],(28,28)))    f.show()    plt.draw()
#结果展示:上面是原图片,下面是自编码学习到的