[Deep Learning] Saving and Restoring TensorFlow Models

Source: Internet | Editor: 程序博客网 | Time: 2024/04/26 22:15

Defining and using tf.train.Saver()

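The Saver object: saves the variable values held in a TensorFlow Session to a checkpoint, and later restores them into a Session.
Definition: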

model_path = "/tmp/model.ckpt"
saver = tf.train.Saver()

Saving: saver.save(sess, model_path)

save_path = saver.save(sess, model_path)  # returns the path prefix the checkpoint was written to

Restoring: saver.restore(sess, save_path)

saver.restore(sess, model_path)  # here model_path is the same path that saver.save() returned
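Conceptually, a checkpoint maps each variable's name to its value: save() writes that mapping under the given path and returns the path, and restore() reads it back and assigns the values by name. The stdlib-only sketch below mimics just that round trip with a JSON file. `save_vars`, `restore_vars` and the dict standing in for a session are hypothetical, not TensorFlow APIs; a real TF 1.x checkpoint is a binary format spread over several files sharing the prefix (for example `model.ckpt.index` and `model.ckpt.data-00000-of-00001`).

```python
import json
import os
import tempfile


# Hypothetical stand-in for a session: variable name -> current value.
def save_vars(session_vars, path):
    """Write every variable's value keyed by its name, then return the path (like saver.save)."""
    with open(path, "w") as fh:
        json.dump(session_vars, fh)
    return path


def restore_vars(session_vars, path):
    """Assign each saved value back to the variable with the same name (like saver.restore)."""
    with open(path) as fh:
        session_vars.update(json.load(fh))


model_path = os.path.join(tempfile.gettempdir(), "sketch_model.json")

first_session = {"weight": 0.65, "bias": -0.08}  # illustrative values after some training
save_path = save_vars(first_session, model_path)

second_session = {}                               # fresh "session": nothing assigned yet
restore_vars(second_session, save_path)
print(second_session)  # {'weight': 0.65, 'bias': -0.08}
```

Matching by name is why the restoring program must define variables with the same names (here `weight` and `bias`) as the program that saved them.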

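Notes:
1. Create tf.train.Saver() before the Session, and after the variables it should save: when called with no arguments it collects the variables that exist in the graph at that point.
2. Call saver.save() and saver.restore() inside a Session; both take the session as their first argument. restore() itself initializes the variables it restores, so a restored session does not need to run the initializer first.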
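Note 1 follows from how a Saver built with no arguments chooses what to save: it records the variables present in the graph when it is constructed (TF 1.x raises a `ValueError` if there are none). The stdlib-only sketch below models only that ordering rule; `ToyGraph` and `ToySaver` are hypothetical classes, not TensorFlow APIs.

```python
class ToyGraph:
    """Hypothetical graph that only tracks variable names in creation order."""

    def __init__(self):
        self.variables = []

    def variable(self, name):
        self.variables.append(name)
        return name


class ToySaver:
    """Hypothetical saver: snapshots the graph's variable list when constructed."""

    def __init__(self, graph):
        if not graph.variables:
            raise ValueError("No variables to save")
        self.var_names = list(graph.variables)  # a copy: later variables are not tracked


g = ToyGraph()
try:
    ToySaver(g)              # too early: no variables defined yet
    too_early = False
except ValueError:
    too_early = True

g.variable("weight")
g.variable("bias")
saver = ToySaver(g)          # after the variables, before any session: covers both
g.variable("late_var")       # defined after the saver, so it is not saved

print(too_early, saver.var_names)  # True ['weight', 'bias']
```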

Example: using tf.train.Saver()

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 22:59:41 2017
@author: ZMJ
"""
# Uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session, tf.train.Saver)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

print("Package Loaded")
np.random.seed(1)


def f(x, weight, bias):
    return x * weight + bias


# Synthetic data: y = 0.7 * x - 0.1 plus Gaussian noise
Wref = 0.7
Bref = -0.1
n = 20
noise_var = 0.05
train_X = np.random.random((n, 1))
ref_Y = f(train_X, Wref, Bref)
train_Y = ref_Y + noise_var * np.random.randn(n, 1)

model_path = "/tmp/linear_model.ckpt"
lr = 0.01
epochs = 5000
display_step = 250
n_samples = train_X.size

plt.subplot(121)
plt.axis("equal")
plt.plot(train_X[:, 0], ref_Y[:, 0], "ro", label="Original Data")
plt.plot(train_X[:, 0], train_Y[:, 0], "bo", label="Training Data")
plt.title("Scatter Plot of Data")
plt.legend(loc="lower right")

weight = tf.Variable(np.random.randn(), name="weight")
bias = tf.Variable(np.random.randn(), name="bias")
x = tf.placeholder(tf.float32, shape=[n_samples, 1], name="input")
y = tf.placeholder(tf.float32, shape=[n_samples, 1], name="output")

# Model: linear prediction, mean squared error, plain gradient descent
pred = x * weight + bias
cost = tf.reduce_mean(tf.pow(pred - y, 2))
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(cost)
init = tf.global_variables_initializer()

# Saver definition: after the variables (so it covers weight and bias), before any Session
saver = tf.train.Saver()

# First session: initialize, train 500 epochs, then save
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(500):
        sess.run(optimizer, feed_dict={x: train_X, y: train_Y})
        if epoch % display_step == 0:
            c = sess.run(cost, feed_dict={x: train_X, y: train_Y})
            print("Epoch %s .Cost=%s" % (epoch, c))
    print("First Session Completed!")
    save_path = saver.save(sess, model_path)
    print("Save Completed,Save Path = %s" % save_path)

# Second session: restore, then train the remaining epochs
with tf.Session() as sess:
    # sess.run(init) is not needed: restore() assigns the saved weight and bias
    saver.restore(sess, model_path)
    print("Model Restored From %s" % model_path)
    # epoch restarts at 0 here, so "Epoch 0" is the 501st update overall
    for epoch in range(epochs - 500):
        sess.run(optimizer, feed_dict={x: train_X, y: train_Y})
        if epoch % display_step == 0:
            c = sess.run(cost, feed_dict={x: train_X, y: train_Y})
            print("Epoch %s .Cost=%s" % (epoch, c))
    print("Second Session Completed!")
    save_path = saver.save(sess, model_path)
    print("Save Completed,Save Path = %s" % save_path)
    Wop = sess.run(weight)
    Bop = sess.run(bias)
    fop = f(train_X, Wop, Bop)

    plt.subplot(122)
    plt.plot(train_X[:, 0], ref_Y[:, 0], "ro", label="Original Data")
    plt.plot(train_X[:, 0], train_Y[:, 0], "bo", label="Training Data")
    plt.plot(train_X[:, 0], fop[:, 0], "k-", label="Predicted Line")
    plt.title("Predicted Line")
    plt.legend(loc="lower right")
    plt.show()

Printed log:

Epoch 0 .Cost=0.269742
Epoch 250 .Cost=0.0531464
First Session Completed!
Save Completed,Save Path = /tmp/linear_model.ckpt
Model Restored From /tmp/linear_model.ckpt
Epoch 0 .Cost=0.0323754
Epoch 250 .Cost=0.019944
Epoch 500 .Cost=0.0125031
Epoch 750 .Cost=0.00804937
Epoch 1000 .Cost=0.00538358
Epoch 1250 .Cost=0.00378797
Epoch 1500 .Cost=0.00283292
Epoch 1750 .Cost=0.00226127
Epoch 2000 .Cost=0.00191911
Epoch 2250 .Cost=0.00171431
Epoch 2500 .Cost=0.00159173
Epoch 2750 .Cost=0.00151836
Epoch 3000 .Cost=0.00147444
Epoch 3250 .Cost=0.00144815
Epoch 3500 .Cost=0.00143242
Epoch 3750 .Cost=0.001423
Epoch 4000 .Cost=0.00141736
Epoch 4250 .Cost=0.00141399
Second Session Completed!
Save Completed,Save Path = /tmp/linear_model.ckpt
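The log shows what the checkpoint buys: the second session starts at a cost of about 0.032 rather than the 0.27 seen in the first session, because it resumes from the saved weight and bias. The stdlib-only sketch below replays the same 500 + 4500 schedule with plain-Python gradient descent on y = 0.7x - 0.1 plus noise, using a JSON file as the "checkpoint", and checks that resuming gives exactly the same parameters as 5000 uninterrupted steps. It assumes float64 arithmetic and Python's `random` module, so its numbers will not reproduce the TensorFlow log; `mse` and `gd_steps` are hypothetical helpers.

```python
import json
import os
import random
import tempfile


def mse(w, b, xs, ys):
    """Mean squared error of the line w*x + b (the cost in the TF example)."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def gd_steps(w, b, xs, ys, lr, steps):
    """Run `steps` plain gradient-descent updates on the MSE."""
    n = len(xs)
    for _ in range(steps):
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = 2 * sum(r * x for r, x in zip(residuals, xs)) / n
        grad_b = 2 * sum(residuals) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return w, b


random.seed(1)
xs = [random.random() for _ in range(20)]
ys = [0.7 * x - 0.1 + 0.05 * random.gauss(0, 1) for x in xs]
w0, b0 = random.gauss(0, 1), random.gauss(0, 1)
lr, epochs = 0.01, 5000

# Reference: all 5000 updates in one go.
uninterrupted = gd_steps(w0, b0, xs, ys, lr, epochs)

# "First session": 500 updates, then write the checkpoint.
w1, b1 = gd_steps(w0, b0, xs, ys, lr, 500)
path = os.path.join(tempfile.gettempdir(), "sketch_linear.json")
with open(path, "w") as fh:
    json.dump({"weight": w1, "bias": b1}, fh)

# "Second session": start from the checkpoint instead of fresh values.
with open(path) as fh:
    ckpt = json.load(fh)
resumed = gd_steps(ckpt["weight"], ckpt["bias"], xs, ys, lr, epochs - 500)

print(mse(w1, b1, xs, ys) < mse(w0, b0, xs, ys))  # True: training resumes from a lower cost
print(resumed == uninterrupted)                   # True: identical final parameters
```

Exact equality holds here because JSON writes Python floats with a round-trip-safe representation and the resumed loop performs the same operations in the same order.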

