Keras框架训练模型保存及再载入

来源:互联网 发布:北京地铁网络取票机 编辑:程序博客网 时间:2024/06/10 17:42

实验数据MNIST

初次训练模型并保存

import numpy as npfrom keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGD# 载入数据(x_train,y_train),(x_test,y_test) = mnist.load_data()# (60000,28,28)print('x_shape:',x_train.shape)# (60000)print('y_shape:',y_train.shape)# (60000,28,28)->(60000,784)x_train = x_train.reshape(x_train.shape[0],-1)/255.0x_test = x_test.reshape(x_test.shape[0],-1)/255.0# 换one hot格式y_train = np_utils.to_categorical(y_train,num_classes=10)y_test = np_utils.to_categorical(y_test,num_classes=10)# 创建模型,输入784个神经元,输出10个神经元model = Sequential([        Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')    ])# 定义优化器sgd = SGD(lr=0.2)# 定义优化器,loss function,训练过程中计算准确率model.compile(    optimizer = sgd,    loss = 'mse',    metrics=['accuracy'],)# 训练模型model.fit(x_train,y_train,batch_size=64,epochs=5)# 评估模型loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)print('accuracy',accuracy)# 保存模型model.save('model.h5')   # HDF5文件,pip install h5py

这里写图片描述
这里写图片描述

载入初次训练的模型,再训练

import numpy as npfrom keras.datasets import mnistfrom keras.utils import np_utilsfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGDfrom keras.models import load_model# 载入数据(x_train,y_train),(x_test,y_test) = mnist.load_data()# (60000,28,28)print('x_shape:',x_train.shape)# (60000)print('y_shape:',y_train.shape)# (60000,28,28)->(60000,784)x_train = x_train.reshape(x_train.shape[0],-1)/255.0x_test = x_test.reshape(x_test.shape[0],-1)/255.0# 换one hot格式y_train = np_utils.to_categorical(y_train,num_classes=10)y_test = np_utils.to_categorical(y_test,num_classes=10)# 载入模型model = load_model('model.h5')# 评估模型loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)print('accuracy',accuracy)# 训练模型model.fit(x_train,y_train,batch_size=64,epochs=2)# 评估模型loss,accuracy = model.evaluate(x_test,y_test)print('\ntest loss',loss)print('accuracy',accuracy)# 保存参数,载入参数model.save_weights('my_model_weights.h5')model.load_weights('my_model_weights.h5')# 保存网络结构,载入网络结构from keras.models import model_from_jsonjson_string = model.to_json()model = model_from_json(json_string)print(json_string)

这里写图片描述

原创粉丝点击