keras-lenet5
来源:互联网 发布:e72i软件下载 编辑:程序博客网 时间:2024/06/16 21:16
from keras.layers import Input, Conv2D, MaxPooling2D, Dense, Flattenfrom keras.models import Modelfrom keras.optimizers import Adamfrom keras.utils import np_utilsfrom keras.callbacks import ModelCheckpointimport numpy as npimport pandas as pdimg_row = 28img_col = 28#读取训练集数据,图片数据格式是(batch, row*col),但是我们在训练的过程中输入是(batch,row,col,channel),所以我们需要处理成指定的格式。def create_train(path): img_matrix = pd.read_csv(path).values total = img_matrix.shape[0] labels = img_matrix[:, 0] x_train = img_matrix[:, 1:] return labels.astype(np.uint8), x_train.reshape(total, img_row, img_col, 1)def create_test(path): img_matrix = pd.read_csv(path).values total = img_matrix.shape[0] return img_matrix.reshape(total, img_row, img_col, 1)def create_lenet(): inputs = Input((img_row, img_col, 1)) conv1 = Conv2D(filters=6, kernel_size=(5, 5), strides=1, padding='valid', activation='relu')(inputs) pool1 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(conv1) conv2 = Conv2D(filters=16, kernel_size=(5, 5), strides=1, padding='valid', activation='relu')(pool1) pool2 = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(conv2) conv3 = Conv2D(filters=120, kernel_size=(1, 1), strides=1, padding='valid')(pool2) full1 = Flatten()(conv3)#卷积层是(batch,rows,cols,channel),为了全连接需要将其展开得到(batch,rows*cols*channel)的平铺层 full2 = Dense(units=84)(full1) outputs = Dense(units=10, activation='softmax')(full2) model = Model(inputs=[inputs], outputs=[outputs]) model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy']) return modeldef train_and_predict(): print('-' * 30) print('Loading and preprocessing train data...') print('-' * 30) labels, x_train = create_train('dataset/Digital/train.csv') # Make the value floats in [0;1] instead of int in [0;255] x_train = x_train.astype('float32') x_train /= 255 # convert class vectors to binary class matrices (ie one-hot vectors) labels = np_utils.to_categorical(labels, 10) print('-' * 30) print('Creating and compiling model...') print('-' * 30) lenet5 = create_lenet() model_checkpoint = ModelCheckpoint('weights.h5', monitor='val_loss', save_best_only=True) print('-' * 30) print('Fitting model...') print('-' * 30) lenet5.fit(x_train, labels, batch_size=32, epochs=20, verbose=1, shuffle=True, validation_split=0.2, callbacks=[model_checkpoint]) print('-' * 30) print('Loading and preprocessing test data...') print('-' * 30) x_test = create_test('dataset/Digital/test.csv') x_test = x_test.astype('float32') x_test /= 255 print('-' * 30) print('Loading saved weights...') print('-' * 30) lenet5.load_weights('weights.h5') print('-' * 30) print('Predicting test data...') print('-' * 30) y_pred = lenet5.predict(x_test, verbose=1) np.savetxt('mnist-pred.csv', np.c_[range(1, len(y_pred) + 1), y_pred], delimiter=',', header='ImageId,Label', comments='', fmt='%d') print('-' * 30) print('Saving predicted masks to files...') print('-' * 30)if __name__ == '__main__': train_and_predict()
阅读全文
0 0
- keras-lenet5
- keras
- keras
- keras
- Keras
- keras
- Keras
- keras
- LeNet5的基本结构
- LeNet5的深入解析
- 卷积神经网络之LeNet5
- LeNet5各层伪代码
- tensorflow实现lenet5
- 卷积神经网络LeNet5结构
- 【深度学习】TensorFlow实现LeNet5
- LeNet5训练Mnist回顾总结
- 深度学习--LeNet5参数理解
- TensorFlow MNIST CNN LeNet5模型
- 国庆第一天上班,长夜漫漫,无心睡眠,一大波靓照奉上!
- 【HPUoj】1220
- Java 里如何实现线程间通信?
- 停止Tomcat,控制台没有输出信息
- Maven+Spring+SpringMVC+MyBatis+MySQL 整合SSM框架
- keras-lenet5
- Greedy Gift Givers 贪婪的送礼者
- 使用HTML5技术控制电脑或手机上的摄像头
- 数据结构 P35 算法实现 双向循环链表的创建
- Oracle 中 decode 函数用法
- ffmpeg编译与搭建笔记
- Django中的request与response对象
- Android 绘制中国地图及热点省份分布
- Google Dremel数据模型详解