基于Theano的深度学习框架keras及配合SVM训练模型
来源:互联网 发布:必修3基本算法语句ppt 编辑:程序博客网 时间:2024/06/06 10:52
1.介绍
Keras是基于Theano的一个深度学习框架,它的设计参考了Torch,用Python语言编写,是一个高度模块化的神经网络库,支持GPU和CPU。keras官方文档地址 地址
2.流程
先使用CNN进行训练,利用Theano函数将CNN全连接层的值取出来,给SVM进行训练
3.结果示例
因为这里只是一个演示keras&SVM的demo,未对参数进行过多的尝试,结果一般
4.代码
由于keras文档、代码更新,目前网上很多代码都不能使用,下面贴上我的代码,可以直接运行
from keras.models import Sequentialfrom keras.layers.core import Dense, Dropout, Activation,Flattenfrom keras.layers.convolutional import Convolution2D, MaxPooling2Dfrom keras.optimizers import SGDfrom keras.datasets import mnistfrom keras.layers import BatchNormalizationfrom sklearn.svm import SVCimport theanofrom keras.utils import np_utilsdef svc(traindata,trainlabel,testdata,testlabel): print("Start training SVM...") svcClf = SVC(C=1.0,kernel="rbf",cache_size=3000) svcClf.fit(traindata,trainlabel) pred_testlabel = svcClf.predict(testdata) num = len(pred_testlabel) accuracy = len([1 for i in range(num) if testlabel[i]==pred_testlabel[i]])/float(num) print("cnn-svm Accuracy:",accuracy)#each add as one layermodel = Sequential()#1 .use convolution,pooling,full connectionmodel.add(Convolution2D(5, 3, 3,border_mode='valid',input_shape=(1, 28, 28),activation='tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Convolution2D(10, 3, 3,activation='tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Flatten())model.add(Dense(100,activation='tanh')) #Full connectionmodel.add(Dense(10,activation='softmax'))#2 .just only user full connection# model.add(Dense(100,input_dim = 784, init='uniform',activation='tanh'))# model.add(Dense(100,init='uniform',activation='tanh'))# model.add(Dense(10,init='uniform',activation='softmax'))# sgd = SGD(lr=0.2, decay=1e-6, momentum=0.9, nesterov=True)model.compile(optimizer='sgd', loss='categorical_crossentropy')(X_train, y_train), (X_test, y_test) = mnist.load_data()#change data type,keras category need ont hot#2 reshape#X_train = X_train.reshape(X_train.shape[0],X_train.shape[1]*X_train.shape[2]) #X_train.shape[0] 60000 X_train.shape[1] 28 X_train.shape[2] 28#1 reshapeX_train = X_train.reshape(X_train.shape[0],1,X_train.shape[1],X_train.shape[2])Y_train = np_utils.to_categorical(y_train, 10)#new label for svmy_train_new = y_train[0:42000]y_test_new = y_train[42000:]#new train and test dataX_train_new = X_train[0:42000]X_test = X_train[42000:]Y_train_new = Y_train[0:42000]Y_test = Y_train[42000:]model.fit(X_train_new, Y_train_new, batch_size=200, nb_epoch=100,shuffle=True, verbose=1, show_accuracy=True, validation_split=0.2)print("Validation...")val_loss,val_accuracy = model.evaluate(X_test, Y_test, batch_size=1,show_accuracy=True)print "val_loss: %f" %val_lossprint "val_accuracy: %f" %val_accuracy#define theano funtion to get output of FC layerget_feature = theano.function([model.layers[0].input],model.layers[5].get_output(train=False),allow_input_downcast=False)FC_train_feature = get_feature(X_train_new)FC_test_feature = get_feature(X_test)svc(FC_train_feature,y_train_new,FC_test_feature,y_test_new)
0 0
- 基于Theano的深度学习框架keras及配合SVM训练模型
- 基于Theano的深度学习框架keras及配合SVM训练模型
- 基于Theano的深度学习框架keras及配合SVM训练模型 (非常好的思路:DL+DM)
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-05-模型
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-05-模型
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-05-模型
- 基于theano的深度学习框架Keras的使用
- 基于Theano的深度学习(Deep Learning)框架Keras
- keras基于theano和tensorflow训练的模型相互转换
- keras深度学习框架的训练保存及调用
- 深度学习框架keras安装(后端基于Tensorflow/theano)
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-01-FAQ
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-02-Example
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-03-优化器
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-04-目标函数
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-06-激活函数
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-07-初始化权值
- 基于Theano的深度学习(Deep Learning)框架Keras学习随笔-08-规则化(规格化)
- Java中Set类初始化问题
- Hadoop安装及开发
- jetbrains系列IDE-Vmoptions 优化指南
- IAR的STlink下载出现 Failed to set configuration with MCU name STM8S207MB: SWIM error [30006]:解决办法
- Django的安装
- 基于Theano的深度学习框架keras及配合SVM训练模型
- Cocos2d-x学习(3) - cocos2d坐标系,锚点
- 22.复杂链表的复制(做第二遍时感觉仍有难度,第三次做还是要看思路)
- 结构体字节对齐
- VMware中centos6.7中设置静态IP
- HTML DOM querySelector() 方法
- HDU 1227 dp距离和最小,中位数的应用
- 开源日历控件DatePicker源码解析
- Umeng的手动的去刷新更新