Tensorflow 使用slim框架下的分类模型进行分类
来源:互联网 发布:mac存储图片快捷键 编辑:程序博客网 时间:2024/06/02 05:14
根据TF开发人员是说法Tensorflow对于模型读写的保存和调用的步骤一般如下:Build your graph –> write your graph –> import from written graph –> run compute etc。
以下我们使用slim提供的网络Vgg19作为例子:
1. export inference graph
import tensorflow as tf
from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory
import nets.vgg as net
slim = tf.contrib.slim
tf.app.flags.DEFINE_string(
‘model_name’, ‘vgg_19’, ‘The name of the architecture to save.’)
tf.app.flags.DEFINE_boolean(
‘is_training’, False,
‘Whether to save out a training-focused version of the model.’)
tf.app.flags.DEFINE_integer(
‘default_image_size’, 224,
‘The image size to use if the model does not define it.’)
tf.app.flags.DEFINE_string(‘dataset_name’, ‘imagenet’,
‘The name of the dataset to use with the model.’)
tf.app.flags.DEFINE_integer(
‘labels_offset’, 0,
‘An offset for the labels in the dataset. This flag is primarily used to ’
‘evaluate the VGG and ResNet architectures which do not use a background ’
‘class for the ImageNet dataset.’)
tf.app.flags.DEFINE_string(
‘output_file’, ’ /log/model_graph.pb’, ‘Where to save the resulting file to.’)
tf.app.flags.DEFINE_string(
‘dataset_dir’, ”, ‘Directory to save intermediate dataset files to’)
FLAGS = tf.app.flags.FLAGS
def main(_):
if not FLAGS.output_file:
raise ValueError(‘You must supply the path to save to with –output_file’)
tf.logging.set_verbosity(tf.logging.INFO)
# checkpoint path
checkpoint_path = “You cpkt model path” # ckpt file obtained during model training or fine-tuning
# set up and load session
sess = tf.Session()
arg_scope = net.vgg_arg_scope()
network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=(19 - FLAGS.labels_offset),
is_training=FLAGS.is_training)
if hasattr(network_fn, ‘default_image_size’):
image_size = network_fn.default_image_size
else:
image_size = FLAGS.default_image_size
placeholder = tf.placeholder(name=’input’, dtype=tf.float32,
shape=[1, image_size, image_size, 3])
with slim.arg_scope(arg_scope):
logits, end_points = network_fn(placeholder)
probabilities = tf.nn.softmax(logits)
result = tf.identity(probabilities,’output’)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
with gfile.GFile(FLAGS.output_file, ‘wb’) as f:
f.write(sess.graph_def.SerializeToString())
f.close()
if name == ‘main‘:
tf.app.run()
- freeze model
可以通过bazel执行bazelbuildtensorflow/python/tools:freezegraph bazel-bin/tensorflow/python/tools/freeze_graph \ –input_graph=/your/path/to/model_graph.pb \ # obtained above –input_checkpoint=/your/path/to/vgg-19.ckpt \ –input_binary=true –output_graph=/your/path/to/frozen_graph.pb \ –output_node_names=ouput # output node name defined in inception resnet v2 net
也可以通过python代码执行,如下:
python freeze_graph.py --input_graph=/your/path/to/model_graph.pb --input_checkpoint=/your/path/to/vgg-19.ckpt --input_binary=true --output_graph=/your/path/to/frozen_graph.pb --output_node_names=output
注意:此处model_graph.pb为保存模型的推理图,结果为上一步生成文件。
- inference
import cv2
import numpy as np
from nets import nets_factory
from preprocessing import preprocessing_factory, vgg_preprocessing
import tensorflow as tf
file = r” /data/1.jpg”
eval_image_size = 224 #FLAGS.eval_image_size or network_fn.default_image_size
image_np = cv2.imread(file)
resize to model input image size
image_np = cv2.resize(image_np, (eval_image_size, eval_image_size))
image_np = np.expand_dims(image_np, 0)
load model
with tf.gfile.GFile(’ /log/frozen_graph.pb’) as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name=”)
with tf.Session(graph=graph) as sess:
input_tensor = sess.graph.get_tensor_by_name(“input:0”) # get input tensor
output_tensor = sess.graph.get_tensor_by_name(“output:0”) # get output tensor
logits = sess.run(output_tensor, feed_dict={input_tensor:image_np})
print “Prediciton label index:”, np.argmax(logits, 1)
print “Top 3 Prediciton label index:”, np.argsort(logits[8])
- Tensorflow 使用slim框架下的分类模型进行分类
- Tensorflow 使用slim框架下的分类模型进行分类
- 使用TensorFlow-Slim进行图像分类
- 使用tf-slim的inception_resnet_v2预训练模型进行图像分类
- [Tensorflow]基于slim框架下inception模型的植物识别
- Tensorflow使用slim工具(vgg16模型)实现图像分类与分割
- Tensorflow使用slim工具(vgg16模型)实现图像分类与分割
- 谷歌开源图像分类工具TF-Slim,定义TensorFlow复杂模型
- TensorFlow-Slim图像分类库
- 使用tf-slim的ResNet V1 152和ResNet V2 152预训练模型进行图像分类
- [深度学习框架] Tensorflow上使用CNN进行mnist分类
- TensorFlow学习笔记(11)--【Ubuntu】slim框架下的inception_v4模型的运行、可视化、导出和使用
- TensorFlow-Slim image classification library:TensorFlow-Slim 图像分类库
- 利用Tensorflow的Mobilenet模型在移动端进行舌像识别进行体质分类
- 使用python,在已经配置好的模型下进行imagenet分类
- 使用Tensorflow的slim库进行迁移学习
- 使用caffe训练好的模型进行分类
- java调用tensorflow模型进行图片分类识别
- Linux 软件安装到 /usr,/usr/local/ 还是 /opt 目录?
- TextView长按自由选择复制,弹出popwindow菜单,划线,删除线,做笔记
- 蓝桥杯样题-信用卡号验证
- springmvc框架中实现全选
- 自己写的JAVA实现登录限制原理(SSM框架)
- Tensorflow 使用slim框架下的分类模型进行分类
- FFMPEG录屏软件开发之最终完善
- (洛谷)【P1162】填涂颜色 [广度搜索]
- Ridge回归、Lasso回归、坐标下降法、最小角回归
- [解答]CCF-通信网络-2017
- Spring DI 依赖注入案例(带参数构造方法依赖注入、setter方法依赖注入、p名称空间注入)
- 主机ping不通Virtualbox里的虚拟机
- codeforces 676A Nicholas and Permutation
- 中国“互联网+”数字经济峰会杭州召开 描绘数字经济新版图