Keras搭建的自编码模型

来源:互联网 发布:如何上传网站源码 编辑:程序博客网 时间:2024/06/07 02:45
  1. http://blog.csdn.net/u012458963/article/details/72566596
         https://kiseliu.github.io/2016/08/16/building-autoencoders-in-keras/


  1. import numpy as np  
  2. np.random.seed(1337)  # for reproducibility  
  3.   
  4. from keras.datasets import mnist  
  5. from keras.models import Model #泛型模型  
  6. from keras.layers import Dense, Input  
  7. import matplotlib.pyplot as plt  
  8.   
  9. # X shape (60,000 28x28), y shape (10,000, )  
  10. (x_train, _), (x_test, y_test) = mnist.load_data()  
  11.   
  12. # 数据预处理  
  13. x_train = x_train.astype('float32') / 255. - 0.5       # minmax_normalized  
  14. x_test = x_test.astype('float32') / 255. - 0.5         # minmax_normalized  
  15. x_train = x_train.reshape((x_train.shape[0], -1))  
  16. x_test = x_test.reshape((x_test.shape[0], -1))  
  17. print(x_train.shape)  
  18. print(x_test.shape)  
  19.   
  20. # 压缩特征维度至2维  
  21. encoding_dim = 2  
  22.   
  23. # this is our input placeholder  
  24. input_img = Input(shape=(784,))  
  25.   
  26. # 编码层  
  27. encoded = Dense(128, activation='relu')(input_img)  
  28. encoded = Dense(64, activation='relu')(encoded)  
  29. encoded = Dense(10, activation='relu')(encoded)  
  30. encoder_output = Dense(encoding_dim)(encoded)  
  31.   
  32. # 解码层  
  33. decoded = Dense(10, activation='relu')(encoder_output)  
  34. decoded = Dense(64, activation='relu')(decoded)  
  35. decoded = Dense(128, activation='relu')(decoded)  
  36. decoded = Dense(784, activation='tanh')(decoded)  
  37.   
  38. # 构建自编码模型  
  39. autoencoder = Model(inputs=input_img, outputs=decoded)  
  40.   
  41. # 构建编码模型  
  42. encoder = Model(inputs=input_img, outputs=encoder_output)  
  43.   
  44. # compile autoencoder  
  45. autoencoder.compile(optimizer='adam', loss='mse')  
  46.   
  47. # training  
  48. autoencoder.fit(x_train, x_train, epochs=20, batch_size=256, shuffle=True)  
  49.   
  50. # plotting  
  51. encoded_imgs = encoder.predict(x_test)  
  52. plt.scatter(encoded_imgs[:, 0], encoded_imgs[:, 1], c=y_test, s=3)  
  53. plt.colorbar()  
  54. plt.show()  
------------------------------------------------------------
 http://blog.csdn.net/u012458963/article/details/72566596