Keras框架训练模型保存及再载入
来源:互联网 发布:北京地铁网络取票机 编辑:程序博客网 时间:2024/06/10 17:42
实验数据MNIST
初次训练模型并保存
import numpy as npfrom keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGD# 载入数据(x_train,y_train),(x_test,y_test) = mnist.load_data()# (60000,28,28)print('x_shape:',x_train.shape)# (60000)print('y_shape:',y_train.shape)# (60000,28,28)->(60000,784)x_train = x_train.reshape(x_train.shape[0],-1)/255.0x_test = x_test.reshape(x_test.shape[0],-1)/255.0# 换one hot格式y_train = np_utils.to_categorical(y_train,num_classes=10)y_test = np_utils.to_categorical(y_test,num_classes=10)# 创建模型,输入784个神经元,输出10个神经元model = Sequential([ Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax') ])# 定义优化器sgd = SGD(lr=0.2)# 定义优化器,loss function,训练过程中计算准确率model.compile( optimizer = sgd, loss = 'mse', metrics=['accuracy'],)# 训练模型model.fit(x_train,y_train,batch_size=64,epochs=5)# 评估模型loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)print('accuracy',accuracy)# 保存模型model.save('model.h5') # HDF5文件,pip install h5py
载入初次训练的模型,再训练
import numpy as npfrom keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGDfrom keras.models import load_model# 载入数据(x_train,y_train),(x_test,y_test) = mnist.load_data()# (60000,28,28)print('x_shape:',x_train.shape)# (60000)print('y_shape:',y_train.shape)# (60000,28,28)->(60000,784)x_train = x_train.reshape(x_train.shape[0],-1)/255.0x_test = x_test.reshape(x_test.shape[0],-1)/255.0# 换one hot格式y_train = np_utils.to_categorical(y_train,num_classes=10)y_test = np_utils.to_categorical(y_test,num_classes=10)# 载入模型model = load_model('model.h5')# 评估模型loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)print('accuracy',accuracy)# 训练模型model.fit(x_train,y_train,batch_size=64,epochs=2)# 评估模型loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)print('accuracy',accuracy)# 保存参数,载入参数model.save_weights('my_model_weights.h5')model.load_weights('my_model_weights.h5')# 保存网络结构,载入网络结构from keras.models import model_from_jsonjson_string = model.to_json()model = model_from_json(json_string)print(json_string)
阅读全文
0 0
- Keras框架训练模型保存及再载入
- 保存Keras训练的模型
- keras深度学习框架的训练保存及调用
- Joone框架的神经网络如何保存和载入训练好的模型
- 基于Theano的深度学习框架keras及配合SVM训练模型
- 基于Theano的深度学习框架keras及配合SVM训练模型
- 如何保存Keras模型
- 如何保存Keras模型?
- 如何保存Keras模型
- keras如何保存模型
- 如何保存keras模型
- 基于Theano的深度学习框架keras及配合SVM训练模型 (非常好的思路:DL+DM)
- Keras中实现mnist神经网络训练与模型保存(采用LeNet-5模型)
- ubuntu中利用h5py保存训练好的keras 神经网络模型
- Ubuntu中利用h5py保存训练好的keras神经网络模型
- Keras模型的加载和保存、预训练、按层名匹配参数
- keras 大数据的训练,迭代载入内存
- keras 保存模型和加载模型
- 使用w查看系统负载、vmstat、top、sar、nload命令
- 一个漂亮的php验证码类(分享)
- Eclipse 启动tomcat 问题
- ffmpeg: error while loading shared libraries: libavdevice.so.57
- SpringBoot部署到服务器Tomcat添加server.context-path后静态资源、请求等404
- Keras框架训练模型保存及再载入
- HDU-2037
- 最大子列和问题
- 【枚举算法】枚举法概念
- spring中各个模块的作用
- a letter and a number
- LRU、LFU算法java实现
- POJ 1260.Pearls
- 使用git pull文件时和本地文件冲突怎么办