mxnet保存模型,加载模型来预测新数据

来源:互联网 发布:电脑游戏截图软件 编辑:程序博客网 时间:2024/06/05 11:40

mxnet保存模型,以及用模型来预测新数据

我们希望训练好之后的模型,可以保存下来,然后需要预测新数据的时候,就可以拿来用,可以这样做。 

  我们以线性回归的例子来讲: 
1,训练并保存模型。

import mxnet as mximport numpy as npimport logginglogging.getLogger().setLevel(logging.DEBUG)# Training datatrain_data = np.random.uniform(0, 1, [100, 2])train_label = np.array([train_data[i][0] + 2 * train_data[i][1] for i in range(100)])batch_size = 1num_epoch=5# Evaluation Dataeval_data = np.array([[7,2],[6,10],[12,2]])eval_label = np.array([11,26,16])train_iter = mx.io.NDArrayIter(train_data,train_label, batch_size, shuffle=True,label_name='lin_reg_label')eval_iter = mx.io.NDArrayIter(eval_data, eval_label, batch_size, shuffle=False)X = mx.sym.Variable('data')Y = mx.sym.Variable('lin_reg_label')fully_connected_layer  = mx.sym.FullyConnected(data=X, name='fc1', num_hidden = 1)lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")model = mx.mod.Module(    symbol = lro ,    data_names=['data'],    label_names = ['lin_reg_label']  # network structure)model.fit(train_iter, eval_iter,            optimizer_params={'learning_rate':0.005, 'momentum': 0.9},            num_epoch=50,            eval_metric='mse',)model.predict(eval_iter).asnumpy()metric = mx.metric.MSE()model.score(eval_iter, metric)keys = model.get_params()[0].keys() # 列出所有权重名称print(keys)conv_w = model.get_params()[0]['fc1_weight'] # 获取想要查看的权重信息,如conv_weightbias = model.get_params()[0]['fc1_bias']print(conv_w.asnumpy()) # 查看具体数值print(bias.asnumpy())# save model, test stands for prefix of model, num_epoch stands for the epoch number of the modelmodel.save_checkpoint('test',num_epoch) # 保存模型 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

  运行结果为:

dict_keys(['fc1_weight', 'fc1_bias'])[[ 0.99999714  1.99999332]]INFO:root:Saved checkpoint to "test-0005.params"
  • 1
  • 2
  • 3

  被保存下来的文件分别是: 
    test-symbol.json 
    test-num_epoch.params 
  

2,下载模型并使用。

import mxnet as mximport numpy as npbatch_size = 1num_batch = 5# Adding 0.1 to each of the valueseval_data = np.array([[7,2],[6,10],[12,2]])eval_label = np.array([11.1,26.1,16.1]) eval_iter = mx.io.NDArrayIter(eval_data, eval_label, batch_size, shuffle=False)# load modelsym,arg_params,aux_params = mx.model.load_checkpoint('test', 5)mod = mx.mod.Module(symbol=sym,context=mx.gpu(),data_names=['data'],label_names=['lin_reg_label'])mod.bind(for_training=False,data_shapes=[('data', (1, 2))])mod.set_params(arg_params,aux_params)# use modelpredict_stress = mod.predict(eval_iter, num_batch)print(predict_stress.asnumpy())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

  运行结果为:

[[ 10.99997139] [ 25.9999218 ] [ 15.99995708]]
原创粉丝点击