C++ API载入tensorflow graph
来源:互联网 发布:淘宝美工设计培训 编辑:程序博客网 时间:2024/06/05 14:13
通过C++ API载入tensorflow graph
在tensorflow repo中,和C++相关的tutorial远没有python的那么详尽。这篇文章主要介绍如何利用C++来载入一个预训练好的graph,以便于单独使用或者嵌入到其他app中。
Requirements
安装bazel:tensorflow是使用bazel来进行编译的,所以如果要编译其他需要用到tensorflow的文件,我们就需要用到bazel。关于bazel,如果想要了解更多,可以参考我的另外两篇博客:Bazel入门:编译C++项目,Bazel入门2:C++编译常见用例。
Clone TensorFlow repo。
git clone --recursive https://github.com/tensorflow/tensorflow
构建graph
我们首先创建一个tensorflow graph,然后保存成protobuf备用。
import tensorflow as tfimport numpy as npwith tf.Session() as sess: a = tf.Variable(5.0, name='a') b = tf.Variable(6.0, name='b') c = tf.multiply(a, b, name="c") sess.run(tf.global_variables_initializer()) print a.eval() # 5.0 print b.eval() # 6.0 print c.eval() # 30.0 tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)
创建二进制文件
让我们在tensorflow/tensorflow目录下创建一个名叫loader的目录,即tensorflow/tensorflow/loader
,用于载入之前我们创建好的graph。
在loader/
目录下我们再创建一个新的文件叫做loader.cc
。在loader.cc
里我们要做以下几件事情:
- 初始化一个tensorflow session
- 载入之前我们创建好的graph
- 将这个graph加入到session里面
- 设置好输入输出
- 运行graph,得到输出
- 读取输出中的值
- 关闭session,释放资源
#include "tensorflow/core/public/session.h"#include "tensorflow/core/platform/env.h"using namespace tensorflow;int main(int argc, char* argv[]) { // Initialize a tensorflow session Session* session; Status status = NewSession(SessionOptions(), &session); if (!status.ok()) { std::cout << status.ToString() << "\n"; return 1; } // Read in the protobuf graph we exported // (The path seems to be relative to the cwd. Keep this in mind // when using `bazel run` since the cwd isn't where you call // `bazel run` but from inside a temp folder.) GraphDef graph_def; status = ReadBinaryProto(Env::Default(), "models/graph.pb", &graph_def); if (!status.ok()) { std::cout << status.ToString() << "\n"; return 1; } // Add the graph to the session status = session->Create(graph_def); if (!status.ok()) { std::cout << status.ToString() << "\n"; return 1; } // Setup inputs and outputs: // Our graph doesn't require any inputs, since it specifies default values, // but we'll change an input to demonstrate. Tensor a(DT_FLOAT, TensorShape()); a.scalar<float>()() = 3.0; Tensor b(DT_FLOAT, TensorShape()); b.scalar<float>()() = 2.0; std::vector<std::pair<string, tensorflow::Tensor>> inputs = { { "a", a }, { "b", b }, }; // The session will initialize the outputs std::vector<tensorflow::Tensor> outputs; // Run the session, evaluating our "c" operation from the graph status = session->Run(inputs, {"c"}, {}, &outputs); if (!status.ok()) { std::cout << status.ToString() << "\n"; return 1; } // Grab the first output (we only evaluated one graph node: "c") // and convert the node to a scalar representation. auto output_c = outputs[0].scalar<float>(); // (There are similar methods for vectors and matrices here: // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h) // Print the results std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30> std::cout << output_c() << "\n"; // 30 // Free any resources used by the session session->Close(); return 0;}
然后我们需要为我们的项目创建一个BUILD
文件,这会告诉bazel要编译什么东西。在BUILD
文件里我们要定义一个cc_binary
,表示输出一个二进制文件。
cc_binary( name = "loader", srcs = ["loader.cc"], deps = [ "//tensorflow/core:tensorflow", ])
那么最终文件结构如下:
- tensorflow/tensorflow/loader/
- tensorflow/tensorflow/loader/loader.cc
- tensorflow/tensorflow/loader/BUILD
编译和运行
- 在tensorflow repo的根目录下,运行./configure
- 在tensorflow/tensorflow/loader目录下,运行bazel build :loader
- 如果编译的时候遇到一大串
undefined reference to ...
的话建议用bazel build —config=monolithic :loader编译,参考https://github.com/tensorflow/tensorflow/issues/13267
- 如果编译的时候遇到一大串
- 在tensorflow repo的根目录下,cd到 bazel-bin/tensorflow/loader目录下
- 将graph protobuf 拷贝到models/graph.pb
- 运行./loader,得到输出!
Reference
- Loading a TensorFlow graph with the C++ API
- tensorflow#issue:Packaged TensorFlow C++ library for bazel-independent use
阅读全文
0 0
- C++ API载入tensorflow graph
- Tensorflow: Graph
- Tensorflow-Graph
- Loading a TensorFlow graph with the C++ API
- tensorflow API简单整理(四、Graph,Operation&Tensor)
- tensorflow检查点载入[转]
- tensorflow之Graph save and restore in python and c++(C++ 中使用tensorflow)
- [TensorFlow]理解Tensorboard Graph
- Tensorflow(2) Graph
- 【Tensorflow】tf.Graph()函数
- tensorflow graph 数据结构
- Tensorflow api
- Tensorflow API
- 如何理解TensorFlow中的Graph
- TensorFlow基础:Graph与Variable
- Tensorflow中Graph的概念
- tensorflow中的session和graph
- TensorFlow学习:Graph和Session
- 计算机网络
- convex hull
- 简单工厂模式
- JavaScript栈内存和堆内存区别
- 第1章 机器学习基础
- C++ API载入tensorflow graph
- word2vec训练中文模型
- jquery的ajax动态下拉列表
- 897BChtholly's request
- 个人收获演讲模板
- http与https之间的区别
- C语言中的static 详细分析
- Go-gin的基本使用
- 查找之顺序查找