tensorflow 加载预训练模型

来源:互联网 发布:单片机小车制作流程 编辑:程序博客网 时间:2024/05/18 15:30

加载预训练模型需要具备两个条件:1.框架结构(知道每一层的名字),2. 预训练好的模型文件.ckpt

加载预训练模型代码如下:

import tensorflowas tf
import numpy as np
weights_1 = tf.Variable(tf.zeros([3,4]))
# weights_2 = tf.Variable(tf.zeros([4,3]))

sess = tf.InteractiveSession()
saver = tf.train.Saver()

saver.restore(sess, '/tmp/checkpoint/model.ckpt')
o_test = np.array([[4.0,3.0, 2.0]], dtype='float32')
label = tf.matmul(o_test, weights_1)
# label = tf.matmul(label, weights_2)
print sess.run(label)




以上为加载模型代码,可以写全变量名,也可只写一部分。可根据输出来定。


其中,model.ckpt训练模型代码如下:

import tensorflowas tf
import numpy as np

i_data = np.array([[5.0,3.0, 2.0]], dtype= 'float32')
i_label= np.array([[15.0,10.0, 22.0]], dtype= 'float32')

weights_1 = tf.Variable(tf.zeros([3,4]))
out_1 = tf.matmul(i_data, weights_1)

weights_2 = tf.Variable(tf.zeros([4,3]))
out = tf.matmul(out_1, weights_2)

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

loss = tf.reduce_mean(tf.square(out- i_label))
training = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

sess = tf.Session()
sess.run(init_op)
for i in range(20000):
sess.run(training)
save_path = saver.save(sess,'/tmp/checkpoint/model.ckpt')

原创粉丝点击