Autoencoder

来源:互联网 发布:淘宝的kindle不能注册 编辑:程序博客网 时间:2024/05/17 02:43

1. Autoencoder



from __future__ import division, print_function, absolute_import

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline

# Two-layer sigmoid autoencoder trained on MNIST (TensorFlow 1.x graph API).
# The network learns to reconstruct its input images; reconstructions of a
# few test images are shown next to the originals at the end.

# Import MNIST data (downloaded to /tmp/data/ if not already present)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Parameters
learning_rate = 0.01
training_epochs = 20
batch_size = 256
display_step = 1
examples_to_show = 10

# Network Parameters
n_hidden_1 = 256  # 1st layer num features
n_hidden_2 = 128  # 2nd layer num features
n_input = 784  # MNIST data input (img shape: 28*28)

# tf Graph input (only pictures)
X = tf.placeholder("float", [None, n_input])

weights = {
    'encoder_h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'encoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'decoder_h1': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
    'decoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}
biases = {
    'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'decoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'decoder_b2': tf.Variable(tf.random_normal([n_input])),
}


def encoder(x):
    """Build the encoder graph: n_input -> n_hidden_1 -> n_hidden_2.

    Args:
        x: Tensor whose last dimension is n_input (here the X placeholder).

    Returns:
        The sigmoid activations of the second encoder layer (n_hidden_2 wide).
    """
    # Encoder hidden layer #1 with sigmoid activation
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
                                   biases['encoder_b1']))
    # Encoder hidden layer #2 with sigmoid activation
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
                                   biases['encoder_b2']))
    return layer_2


def decoder(x):
    """Build the decoder graph: n_hidden_2 -> n_hidden_1 -> n_input.

    Args:
        x: Tensor whose last dimension is n_hidden_2 (the encoder output).

    Returns:
        The sigmoid activations of the second decoder layer, i.e. the
        reconstruction, n_input wide.
    """
    # Decoder hidden layer #1 with sigmoid activation
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
                                   biases['decoder_b1']))
    # Decoder hidden layer #2 with sigmoid activation
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
                                   biases['decoder_b2']))
    return layer_2


# Construct model
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)

# Prediction
y_pred = decoder_op
# Targets (Labels) are the input data.
y_true = X

# Define loss and optimizer, minimize the squared error
cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    total_batch = int(mnist.train.num_examples / batch_size)
    # Training cycle
    for epoch in range(training_epochs):
        # Loop over all batches
        for i in range(total_batch):
            # Labels are not needed: the autoencoder's target is its input.
            batch_xs, _ = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
        # Display logs per epoch step; c is the cost of the epoch's last batch.
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1),
                  "cost=", "{:.9f}".format(c))
    print("Optimization Finished!")

    # Applying encode and decode over test set
    encode_decode = sess.run(
        y_pred, feed_dict={X: mnist.test.images[:examples_to_show]})

# Compare original images (top row) with their reconstructions (bottom row).
# Plotting happens after the session is closed; it only needs NumPy arrays.
# squeeze=False keeps `a` 2-D even when examples_to_show == 1, and the column
# count follows examples_to_show instead of a hard-coded 10.
f, a = plt.subplots(2, examples_to_show, figsize=(examples_to_show, 2),
                    squeeze=False)
for i in range(examples_to_show):
    a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
# A single plt.show() replaces show(f)/draw()/waitforbuttonpress(): on a
# non-GUI backend waitforbuttonpress() raises NotImplementedError, while
# show() only emits a warning there and blocks until the window is closed
# on GUI backends.
plt.show()



Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
Epoch: 0001 cost= 0.202572390
Epoch: 0002 cost= 0.170586497
Epoch: 0003 cost= 0.147145674
Epoch: 0004 cost= 0.134183317
Epoch: 0005 cost= 0.129831925
Epoch: 0006 cost= 0.125962004
Epoch: 0007 cost= 0.118559957
Epoch: 0008 cost= 0.112685442
Epoch: 0009 cost= 0.109631285
Epoch: 0010 cost= 0.105650857
Epoch: 0011 cost= 0.103282064
Epoch: 0012 cost= 0.101963483
Epoch: 0013 cost= 0.099665105
Epoch: 0014 cost= 0.100522719
Epoch: 0015 cost= 0.097489521
Epoch: 0016 cost= 0.093438946
Epoch: 0017 cost= 0.093155362
Epoch: 0018 cost= 0.091413260
Epoch: 0019 cost= 0.090430483
Epoch: 0020 cost= 0.090419836
Optimization Finished!
<matplotlib.figure.Figure at 0x7fcb187e0278>
/home/wgb/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:402: UserWarning: matplotlib is currently using a non-GUI backend, so cannot show the figure
  "matplotlib is currently using a non-GUI backend, "
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-14-ebd2d41d806d> in <module>()
    108     plt.show(f)
    109     plt.draw()
--> 110     plt.waitforbuttonpress()

/home/wgb/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in waitforbuttonpress(*args, **kwargs)
    724     If *timeout* is negative, does not timeout.
    725     """
--> 726     return gcf().waitforbuttonpress(*args, **kwargs)
    727 
    728 

/home/wgb/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py in waitforbuttonpress(self, timeout)
   1681 
   1682         blocking_input = BlockingKeyMouseInput(self)
-> 1683         return blocking_input(timeout=timeout)
   1684 
   1685     def get_default_bbox_extra_artists(self):

/home/wgb/anaconda3/lib/python3.6/site-packages/matplotlib/blocking_input.py in __call__(self, timeout)
    374         """
    375         self.keyormouse = None
--> 376         BlockingInput.__call__(self, n=1, timeout=timeout)
    377 
    378         return self.keyormouse

/home/wgb/anaconda3/lib/python3.6/site-packages/matplotlib/blocking_input.py in __call__(self, n, timeout)
    115         try:
    116             # Start event loop
--> 117             self.fig.canvas.start_event_loop(timeout=timeout)
    118         finally:  # Run even on exception like ctrl-c
    119             # Disconnect the callbacks

/home/wgb/anaconda3/lib/python3.6/site-packages/matplotlib/backend_bases.py in start_event_loop(self, timeout)
   2412         This is implemented only for backends with GUIs.
   2413         """
-> 2414         raise NotImplementedError
   2415 
   2416     def stop_event_loop(self):

NotImplementedError: 


0 0
原创粉丝点击