Keras + Theano: a small example
Source: the Internet | Editor: 程序博客网 | Time: 2024/04/30 18:13
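Dataset: MNIST
Network: three-layer model (784 inputs, one hidden layer of 386 sigmoid units, 10-way softmax output)
Optimizer: SGD (with learning-rate decay and Nesterov momentum)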
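To make the three-layer model concrete, here is a minimal NumPy sketch of its forward pass: 784 pixel inputs, 386 sigmoid hidden units and a 10-way softmax, matching the `Dense` layers in the script. The helper names (`init_params`, `forward`) are illustrative, not Keras functions. The weight range of ±0.05 is an assumption meant to mirror Keras 1.x's `init='uniform'`; for brevity the sketch uses it for both layers, whereas the script only sets it on the hidden layer and leaves the output layer on Keras's default initializer.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    # subtract each row's max before exponentiating for numerical stability
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def init_params(n_in=784, n_hidden=386, n_out=10, scale=0.05, seed=0):
    # ASSUMPTION: weights ~ U(-0.05, 0.05) for both layers, biases start at zero
    rng = np.random.RandomState(seed)
    return {
        "W1": rng.uniform(-scale, scale, (n_in, n_hidden)),
        "b1": np.zeros(n_hidden),
        "W2": rng.uniform(-scale, scale, (n_hidden, n_out)),
        "b2": np.zeros(n_out),
    }


def forward(params, x):
    h = sigmoid(x.dot(params["W1"]) + params["b1"])     # hidden layer, shape (batch, 386)
    return softmax(h.dot(params["W2"]) + params["b2"])  # class probabilities, shape (batch, 10)


params = init_params()
print(sum(p.size for p in params.values()))  # 306880
probs = forward(params, np.random.RandomState(1).rand(5, 784))
print(probs.shape)  # (5, 10)
```

The parameter count is 784×386 + 386 + 386×10 + 10 = 306,880 weights and biases, which is what the two `Dense` layers hold between them.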
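```python
# coding:utf-8
# Keras 1.x API (init=, nb_epoch=) running on the Theano backend
from __future__ import print_function

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import SGD

from mykeras import minst_pkl  # the MNIST loader module shown below

# load MNIST; the loader returns one-hot encoded labels
minst = minst_pkl()
train_x, train_y, valid_x, valid_y, test_x, test_y = minst.getIMGSets()

# model: 784 inputs -> 386 sigmoid units -> 10-way softmax
model = Sequential()
model.add(Dense(386, input_dim=784, init='uniform'))
model.add(Activation('sigmoid'))
model.add(Dense(10, activation='softmax'))

# compile: SGD with learning-rate decay and Nesterov momentum
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

# train
model.fit(train_x, train_y, nb_epoch=10, batch_size=200)

# evaluate (returns only the test loss, since no metrics were compiled)
# score = model.evaluate(test_x, test_y, batch_size=200)
# print("loss=", score)

# predict: predicted vs. true labels of the first 5 test images
print(np.argmax(model.predict(test_x[0:5, :]), axis=1))
print(np.argmax(test_y[0:5, :], axis=1))
```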
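The `SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)` line combines three ideas. Below is a NumPy sketch of our reading of how Keras 1.x/2.x applies them to each weight; it is not the library's source. It assumes `iterations` counts mini-batch updates and that the gradient `g` is taken at the current weights. The helper names are hypothetical, and the toy objective f(w) = w²/2 only serves to check that the update converges.

```python
import numpy as np


def decayed_lr(lr, decay, iterations):
    # time-based decay: the rate shrinks as 1 / (1 + decay * iterations)
    return lr * (1.0 / (1.0 + decay * iterations))


def sgd_nesterov_step(p, m, g, lr, momentum):
    """One update for weights p with stored velocity m and gradient g (taken at p)."""
    v = momentum * m - lr * g           # new velocity
    new_p = p + momentum * v - lr * g   # Nesterov variant: momentum * v - lr * g instead of plain v
    return new_p, v


# Toy check on f(w) = 0.5 * w**2, whose gradient is w, using the script's settings.
w, m = np.array([5.0]), np.zeros(1)
for it in range(200):
    lr = decayed_lr(0.1, 1e-6, it)
    w, m = sgd_nesterov_step(w, m, w, lr, momentum=0.9)
print("w after 200 updates:", w[0])
```

With 50,000 training images and `batch_size=200` there are 250 updates per epoch, so after 10 epochs (2,500 updates) the decayed learning rate is still about 0.0998. At this scale `decay=1e-6` barely changes the step size.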
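`mykeras.py`, the loader module imported by the script above:

```python
# mykeras.py: loads data/mnist.pkl.gz and one-hot encodes the labels
from __future__ import print_function

import gzip
import sys

import numpy as np

try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle  # Python 3


class minst_pkl(object):
    def loadIMGData(self, dataset):
        # the file holds a pickled (train_set, valid_set, test_set) tuple
        with gzip.open(dataset, 'rb') as f:
            if sys.version_info[0] >= 3:
                # pickled under Python 2; latin1 lets Python 3 decode the numpy arrays
                train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
            else:
                train_set, valid_set, test_set = pickle.load(f)
        return [train_set, valid_set, test_set]

    def getIMGSets(self):
        print("load dataset")
        sets = self.loadIMGData("data/mnist.pkl.gz")
        train_x, train_y = sets[0]
        valid_x, valid_y = sets[1]
        test_x, test_y = sets[2]
        print("train image,label shape:", train_x.shape, train_y.shape)
        print("valid image,label shape:", valid_x.shape, valid_y.shape)
        print("test image,label shape:", test_x.shape, test_y.shape)
        print("load dataset end")
        return [train_x, self.transform(train_y),
                valid_x, self.transform(valid_y),
                test_x, self.transform(test_y)]

    def transform(self, data):
        # one-hot encode integer labels 0-9 into an (m, 10) matrix
        m = np.shape(data)[0]
        convert = np.zeros((m, 10))
        for i in np.arange(m):
            convert[i, data[i]] = 1
        return convert


if __name__ == "__main__":
    minst = minst_pkl()
    train_x, train_y, valid_x, valid_y, test_x, test_y = minst.getIMGSets()
```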
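`transform` builds the one-hot label rows with a Python loop. Indexing an identity matrix produces the same matrix in one vectorised step; the sketch below checks that the two agree. `transform_loop` and `transform_vectorized` are hypothetical names used only here. Keras also ships a helper for this, `to_categorical` in `keras.utils.np_utils`.

```python
import numpy as np


def transform_loop(data, n_classes=10):
    # same logic as minst_pkl.transform
    m = np.shape(data)[0]
    convert = np.zeros((m, n_classes))
    for i in np.arange(m):
        convert[i, data[i]] = 1
    return convert


def transform_vectorized(data, n_classes=10):
    # row k of the identity matrix is the one-hot vector for class k
    return np.eye(n_classes)[np.asarray(data)]


labels = np.array([7, 2, 1, 0, 4])
print(np.array_equal(transform_loop(labels), transform_vectorized(labels)))  # True
print(transform_vectorized(labels).argmax(axis=1))  # [7 2 1 0 4]
```

Taking `argmax` along each row undoes the encoding, which is exactly what the script's last two `print` calls rely on.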
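Results: the training log over 10 epochs, followed by the predicted labels and the true labels of the first five test images: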
```
Epoch 1/10
50000/50000 [==============================] - 2s - loss: 0.5389
Epoch 2/10
50000/50000 [==============================] - 2s - loss: 0.3043
Epoch 3/10
50000/50000 [==============================] - 2s - loss: 0.2689
Epoch 4/10
50000/50000 [==============================] - 2s - loss: 0.2357
Epoch 5/10
50000/50000 [==============================] - 2s - loss: 0.2069
Epoch 6/10
50000/50000 [==============================] - 2s - loss: 0.1814
Epoch 7/10
50000/50000 [==============================] - 2s - loss: 0.1619
Epoch 8/10
50000/50000 [==============================] - 2s - loss: 0.1457
Epoch 9/10
50000/50000 [==============================] - 2s - loss: 0.1315
Epoch 10/10
50000/50000 [==============================] - 2s - loss: 0.1195
[7 2 1 0 4]
[7 2 1 0 4]
```
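The two bracketed lines compare `argmax` over the predicted probabilities with `argmax` over the one-hot labels for five images. Applying the same comparison to the whole test set gives an accuracy figure. The `accuracy` helper below is a hypothetical sketch, run on made-up toy arrays standing in for `model.predict(test_x)` and `test_y`. Within Keras itself, passing `metrics=['accuracy']` to `model.compile` makes `model.evaluate` return the accuracy alongside the loss.

```python
import numpy as np


def accuracy(probs, one_hot):
    """Fraction of rows whose most probable class matches the one-hot label."""
    pred = np.argmax(probs, axis=1)
    true = np.argmax(one_hot, axis=1)
    return float(np.mean(pred == true))


# Toy data (made up for illustration) in place of model.predict(test_x) and test_y.
probs = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
one_hot = np.array([[0, 1], [1, 0], [1, 0], [1, 0]])
print(accuracy(probs, one_hot))  # 0.75
```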
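Backend and device messages printed when Keras and Theano start up:

```
Using Theano backend.
Using gpu device 0: GeForce 820M (CNMeM is disabled, cuDNN not available)
```

Run time on GPU vs. CPU:

```
gpu: 29.75
cpu: 99.99
```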
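The post does not say how the two timings were taken or how it switched devices. With older Theano releases the device is normally chosen through the `THEANO_FLAGS` environment variable before Keras is imported, for example `device=gpu,floatX=float32` versus `device=cpu`. Below is a minimal sketch, assuming wall-clock timing of one call such as `model.fit`. `time_call` and `speedup` are hypothetical helpers, and the two numbers are the ones reported above, taken to be in the same unit.

```python
import time


def time_call(fn, *args, **kwargs):
    """Run fn once and return (its result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def speedup(cpu_time, gpu_time):
    """How many times faster the GPU run was (both inputs in the same unit)."""
    return cpu_time / gpu_time


# e.g. _, elapsed = time_call(model.fit, train_x, train_y, nb_epoch=10, batch_size=200)
print(round(speedup(99.99, 29.75), 2))  # 3.36
```

On these figures the GPU run is roughly 3.4 times faster, even with CNMeM disabled and cuDNN unavailable.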