使用 Keras 在 CIFAR10 数据集上实现 VGG16

来源:互联网 发布:电子商务软件教材 编辑:程序博客网 时间:2024/05/16 14:24
# Train a VGG16-style convolutional network on CIFAR-10 with Keras.
#
# The script loads the dataset, builds a 13-conv + 3-dense Sequential model
# (every weighted layer except the final classifier is L2-regularised and
# followed by ReLU and BatchNormalization), and trains it with Nesterov SGD,
# holding out 10% of the training set for validation.
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras import optimizers
import numpy as np
from keras.layers.core import Lambda
from keras import backend as K
from keras.optimizers import SGD
from keras import regularizers

# Import data: images are cast to float32, labels are one-hot encoded
# over the 10 classes.
# NOTE(review): pixel values are only cast, not rescaled or standardised --
# confirm whether that is intentional.
# NOTE(review): x_test / y_test are prepared here but never used below.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

weight_decay = 0.0005  # L2 factor applied to conv and hidden dense kernels
nb_epoch = 100
batch_size = 32


def _add_conv_unit(net, filters, **conv_kwargs):
    """Append a 3x3 'same' L2-regularised Conv2D -> ReLU -> BatchNormalization.

    Extra keyword arguments (e.g. ``input_shape`` for the first layer) are
    forwarded to ``Conv2D``.
    """
    net.add(Conv2D(filters, (3, 3), padding='same',
                   kernel_regularizer=regularizers.l2(weight_decay),
                   **conv_kwargs))
    net.add(Activation('relu'))
    net.add(BatchNormalization())


def _add_dense_unit(net, units):
    """Append an L2-regularised Dense -> ReLU -> BatchNormalization."""
    net.add(Dense(units, kernel_regularizer=regularizers.l2(weight_decay)))
    net.add(Activation('relu'))
    net.add(BatchNormalization())


model = Sequential()

# The sizes in the comments below are the input shape of each layer.

# layer1: 32*32*3
_add_conv_unit(model, 64, input_shape=(32, 32, 3))
model.add(Dropout(0.3))

# layer2: 32*32*64
_add_conv_unit(model, 64)
model.add(MaxPooling2D(pool_size=(2, 2)))

# layer3: 16*16*64
_add_conv_unit(model, 128)
model.add(Dropout(0.4))

# layer4: 16*16*128
_add_conv_unit(model, 128)
model.add(MaxPooling2D(pool_size=(2, 2)))

# layer5: 8*8*128
_add_conv_unit(model, 256)
model.add(Dropout(0.4))

# layer6: 8*8*256
_add_conv_unit(model, 256)
model.add(Dropout(0.4))

# layer7: 8*8*256
_add_conv_unit(model, 256)
model.add(MaxPooling2D(pool_size=(2, 2)))

# layer8: 4*4*256
_add_conv_unit(model, 512)
model.add(Dropout(0.4))

# layer9: 4*4*512
_add_conv_unit(model, 512)
model.add(Dropout(0.4))

# layer10: 4*4*512
_add_conv_unit(model, 512)
model.add(MaxPooling2D(pool_size=(2, 2)))

# layer11: 2*2*512
_add_conv_unit(model, 512)
model.add(Dropout(0.4))

# layer12: 2*2*512
_add_conv_unit(model, 512)
model.add(Dropout(0.4))

# layer13: 2*2*512
_add_conv_unit(model, 512)
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.5))

# layer14: 1*1*512, flattened before the fully connected head
model.add(Flatten())
_add_dense_unit(model, 512)

# layer15: 512
_add_dense_unit(model, 512)

# layer16: 512 -> 10 class probabilities (no regulariser on the classifier)
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))

# `lr` / `decay` are the argument names this script was written against;
# NOTE(review): confirm they match the installed Keras version.
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,
              metrics=['accuracy'])

# Hold out 10% of the training data for validation.
model.fit(x_train, y_train, epochs=nb_epoch, batch_size=batch_size,
          validation_split=0.1, verbose=1)

原创粉丝点击