tensorflow (4): caffe-tensorflow study notes

Source: Internet  Editor: 程序博客网  Time: 2024/05/17 01:06

Following the LeNet example to convert the model and the network:

LeNet Example

Thanks to @Russell91 for this example

This example shows you how to finetune code from the Caffe MNIST tutorial using TensorFlow.
First, you can convert a prototxt model to tensorflow code:

$ ./convert.py examples/mnist/lenet.prototxt --code-output-path=mynet.py

This produces tensorflow code for the LeNet network in mynet.py. The code can be imported as described below in the Inference section. Caffe-tensorflow also lets you convert .caffemodel weight files to .npy files that can be directly loaded from tensorflow:

$ ./convert.py examples/mnist/lenet.prototxt --caffemodel examples/mnist/lenet_iter_10000.caffemodel --data-output-path=mynet.npy

The above command will generate a weight file named mynet.npy.
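
If you convert several models, the two commands above can be assembled from a script. The following is only a minimal sketch: the flag names (`--code-output-path`, `--caffemodel`, `--data-output-path`) are the ones shown above, while the helper name `build_convert_cmd` and its parameters are hypothetical and not part of caffe-tensorflow. It only builds the argument lists and does not run anything.

```python
def build_convert_cmd(prototxt, caffemodel=None, code_out=None, data_out=None,
                      script="./convert.py"):
    """Build the argument list for convert.py (hypothetical helper)."""
    cmd = [script, prototxt]
    if caffemodel is not None:
        # Weights are only needed when a .npy data file is requested
        cmd += ["--caffemodel", caffemodel]
    if code_out is not None:
        cmd.append("--code-output-path=" + code_out)
    if data_out is not None:
        cmd.append("--data-output-path=" + data_out)
    return cmd


# The two invocations from the text above, rebuilt as argument lists
code_cmd = build_convert_cmd("examples/mnist/lenet.prototxt",
                             code_out="mynet.py")
data_cmd = build_convert_cmd("examples/mnist/lenet.prototxt",
                             caffemodel="examples/mnist/lenet_iter_10000.caffemodel",
                             data_out="mynet.npy")

print(" ".join(code_cmd))
print(" ".join(data_cmd))
```

To actually run a command, the list could be passed to `subprocess.run(cmd, check=True)` from the caffe-tensorflow directory.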

Inference:

Once you have generated both the code and the weight files for LeNet, you can finetune LeNet using TensorFlow with

$ ./examples/mnist/finetune_mnist.py

At a high level, finetune_mnist.py works as follows:

# Import the converted model's class
from mynet import MyNet

# Create an instance, passing in the input data
net = MyNet({'data': my_input_data})

with tf.Session() as sesh:
    # Load the data
    net.load('mynet.npy', sesh)
    # Forward pass
    output = sesh.run(net.get_output(), ...)

Code for running the converted model (this part is code I added myself):

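import sys

import numpy as np
from PIL import Image
import tensorflow as tf

# Make the caffe-tensorflow package and the converted yolo.py importable
sys.path.append('/home/yang/caffe-tensorflow')
sys.path.append('/home/yang/caffe-tensorflow/examples/yolo')
import yolo

# Placeholder for a single 448x448 RGB image
image = tf.placeholder(tf.float32, [1, 448, 448, 3])
net = yolo.yolo({'data': image})

# Load the test image and resize it to the network input size
image_path = '/home/yang/darknet/data/dog.jpg'
im = Image.open(image_path)
im_reshape = im.resize((448, 448))
input = np.array(im_reshape)
input = input.reshape((1, 448, 448, 3))
# Shift and scale the pixel values from [0, 255] to roughly [-1, 1]
input = (input * 1.0 - 127.5) * 0.007874015748031496

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # Load the converted weights, then run a forward pass
    net.load('/home/yang/caffe-tensorflow/examples/yolo/yolo.npy', sess)
    output = sess.run(net.get_output(), feed_dict={image: input})

# Write the 1470 output values to a text file, separated by spaces
file = open('/home/yang/Desktop/result.txt', 'w')
for i in range(1470):
    file.write(str(output[0][i]) + ' ')
file.close()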
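
The constant 0.007874015748031496 in the preprocessing step is 1/127, so each pixel is shifted by 127.5 and then divided by 127. A minimal pure-Python sketch of that arithmetic follows; the function name `normalize_pixel` is my own and does not appear in the script. Whether this scaling matches what the original Caffe model was trained with should be checked against that model's own configuration.

```python
SCALE = 0.007874015748031496  # constant taken from the script; equals 1/127


def normalize_pixel(value):
    """Apply the same shift-and-scale as the script to one pixel value."""
    return (value * 1.0 - 127.5) * SCALE


lowest = normalize_pixel(0)      # darkest pixel, slightly below -1
middle = normalize_pixel(127.5)  # midpoint of the range, maps to 0
highest = normalize_pixel(255)   # brightest pixel, slightly above +1

print(lowest, middle, highest)
```

The resulting range is symmetric around zero and slightly wider than [-1, 1], because the divisor is 127 rather than 127.5.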

Additional MNIST test code:

import sys

# Make the converted mynet.py, tensorflow and caffe-tensorflow importable
sys.path.append('/home/yang/caffe-tensorflow/examples/mnist')
sys.path.append('/home/yang/tensorflow')
sys.path.append('/home/yang/caffe-tensorflow')

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/home/yang/data', one_hot=True)

from mynet import LeNet as MyNet

# One flattened 28x28 image and its one-hot label
image = tf.placeholder(tf.float32, [1, 784])
labels = tf.placeholder(tf.float32, [1, 10])
# Reshape the flat vector to the [batch, height, width, channels] layout the network takes
input = tf.reshape(image, shape=[-1, 28, 28, 1])
net = MyNet({'data': input})

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # Load the converted weights, then run one sample through the network
    net.load('/home/yang/caffe-tensorflow/examples/mnist/mynet.npy', sess)
    batch_xs, batch_ys = mnist.train.next_batch(1)
    output = sess.run(net.get_output(), feed_dict={image: batch_xs})
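
The MNIST test stops once `output` has been computed. Assuming `output` has shape [1, 10] with one score per digit, and `batch_ys` is the matching one-hot label (as `one_hot=True` suggests), checking the prediction could be sketched as below. Plain lists stand in for the arrays, the helper names are hypothetical, and the scores are made-up values for illustration, not real network output.

```python
def argmax(values):
    """Return the index of the largest element (the first one on ties)."""
    best = 0
    for i, v in enumerate(values):
        if v > values[best]:
            best = i
    return best


def is_correct(scores, one_hot_label):
    """True if the highest-scoring class matches the one-hot label."""
    return argmax(scores) == argmax(one_hot_label)


# Made-up values for illustration only
example_scores = [0.01, 0.02, 0.05, 0.80, 0.02, 0.03, 0.02, 0.02, 0.02, 0.01]
example_label = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

predicted = argmax(example_scores)
print(predicted, is_correct(example_scores, example_label))
```

In the real script the equivalent call would be `is_correct(output[0], batch_ys[0])`; with NumPy, `np.argmax(output[0])` gives the same index.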