keras迁移学习 使用vgg16进行手写数字识别
来源:互联网 发布:安卓5.0 源码 编辑:程序博客网 时间:2024/06/06 12:37
一个简单的迁移学习案例:使用keras 将vgg16用于手写数字识别
# -*- coding: utf-8 -*-"""Created on Tue Nov 21 22:26:20 2017@author: www"""from keras.models import Modelfrom keras.layers import Dense, Flatten, Dropoutimport cv2from keras import datasetsfrom keras.applications.vgg16 import VGG16from keras.optimizers import SGDfrom keras.datasets import mnistimport numpy as np#迁移学习 使用VGG16进行手写数字识别#只迁移网络结构,不迁移权重model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(224,224,3))model = Flatten(name='Flatten')(model_vgg.output)moel = Dense(10, activation='softmax')(model)model_vgg_mnist = Model(inputs=model_vgg.input, outputs=model, name='vgg16')model_vgg_mnist.summary()#迁移学习;网络结构与权重#ishape = 224model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(ishape, ishape, 3))for layers in model_vgg.layers: layers.trainable = Falsemodel = Flatten()(model_vgg.output)model = Dense(10, activation='softmax')(model)model_vgg_mnist_pretrain = Model(inputs=model_vgg.input, outputs=model, name='vgg16_pretrain')model_vgg_mnist_pretrain.summary()#==============================================================================# Total params: 14,965,578.0# Trainable params: 250,890.0# Non-trainable params: 14,714,688.0#==============================================================================sgd = SGD(lr=0.05, decay=1e-5)model_vgg_mnist_pretrain.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])(X_train,y_train),(X_test,y_test) = mnist.load_data()#转成VGG16需要的格式X_train = [cv2.cvtColor(cv2.resize(i,(ishape,ishape)), cv2.COLOR_GRAY2BGR) for i in X_train]X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32')X_test = [cv2.cvtColor(cv2.resize(i,(ishape,ishape)), cv2.COLOR_GRAY2BGR) for i in X_test ]X_test = np.concatenate([arr[np.newaxis] for arr in X_test] ).astype('float32')#预处理X_train.shapeX_test.shapeX_train /= X_train/255X_test /= X_test/255np.where(X_train[0] != 0)#哑编码def train_y(y): y_one = np.zeros(10) y_one[y] = 1 return y_one y_train_one = np.array([train_y(y_train[i]) for i in range(len(y_train))])y_test_one = np.array([train_y(y_test [i]) for i in range(len(y_test ))])#模型训练model_vgg_mnist_pretrain.fit(X_train, y_train_one, validation_data=(X_test, y_test_one), epochs=200, batch_size=128)
阅读全文
0 0
- keras迁移学习 使用vgg16进行手写数字识别
- keras入门 利用卷积神经网络进行手写数字识别
- keras 入门 --手写数字识别
- Kaggle Digit Recognizer使用keras实现手写数字识别 A1
- Keras入门课2 -- 使用CNN识别mnist手写数字
- Keras 浅尝之MNIST手写数字识别
- keras入门实战:手写数字识别
- keras 实现CNN 进行手写字符识别
- 关于利用机器学习进行手写数字的的识别
- 使用Caffe进行手写数字识别执行流程解析
- matlab 使用libsvm工具箱进行手写数字识别
- 使用逻辑回归和神经网络进行手写数字识别
- [DL]2.使用Softmax回归进行手写数字识别
- keras 手把手入门#1-MNIST手写数字识别 深度学习实战闪电入门
- 在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别
- Caffe学习-手写数字识别
- Caffe学习-手写数字识别
- Caffe学习-手写数字识别
- 使用通配符 --->弥补泛型擦除的不足
- 【数据库】5索引、视图、触发器
- UVALive-7500-Boxes and Balls
- 紫猫安卓按键之表
- Python函数参数类型*、**的区别
- keras迁移学习 使用vgg16进行手写数字识别
- Django缓存系统
- AS 2293.3_2005 (Emergency lighting)应急灯SAA认证
- 如何用c打印出一颗心
- 6.6
- frag嵌套+pull+Xlvdm
- RPG游戏黑暗之光Part5:技能设定与学习
- 数据结构:线性表
- 封装类