C++ API载入tensorflow graph

来源:互联网 发布:淘宝美工设计培训 编辑:程序博客网 时间:2024/06/05 14:13

通过C++ API载入tensorflow graph

在tensorflow repo中,和C++相关的tutorial远没有python的那么详尽。这篇文章主要介绍如何利用C++来载入一个预训练好的graph,以便于单独使用或者嵌入到其他app中。

Requirements

  • 安装bazel:tensorflow是使用bazel来进行编译的,所以如果要编译其他需要用到tensorflow的文件,我们就需要用到bazel。关于bazel,如果想要了解更多,可以参考我的另外两篇博客:Bazel入门:编译C++项目,Bazel入门2:C++编译常见用例。

  • Clone TensorFlow repo。

    git clone --recursive https://github.com/tensorflow/tensorflow

构建graph

我们首先创建一个tensorflow graph,然后保存成protobuf备用。

import tensorflow as tfimport numpy as npwith tf.Session() as sess:    a = tf.Variable(5.0, name='a')    b = tf.Variable(6.0, name='b')    c = tf.multiply(a, b, name="c")    sess.run(tf.global_variables_initializer())    print a.eval() # 5.0    print b.eval() # 6.0    print c.eval() # 30.0    tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)

创建二进制文件

让我们在tensorflow/tensorflow目录下创建一个名叫loader的目录,即tensorflow/tensorflow/loader,用于载入之前我们创建好的graph。

loader/目录下我们再创建一个新的文件叫做loader.cc。在loader.cc里我们要做以下几件事情:

  1. 初始化一个tensorflow session
  2. 载入之前我们创建好的graph
  3. 将这个graph加入到session里面
  4. 设置好输入输出
  5. 运行graph,得到输出
  6. 读取输出中的值
  7. 关闭session,释放资源
#include "tensorflow/core/public/session.h"#include "tensorflow/core/platform/env.h"using namespace tensorflow;int main(int argc, char* argv[]) {  // Initialize a tensorflow session  Session* session;  Status status = NewSession(SessionOptions(), &session);  if (!status.ok()) {    std::cout << status.ToString() << "\n";    return 1;  }  // Read in the protobuf graph we exported  // (The path seems to be relative to the cwd. Keep this in mind  // when using `bazel run` since the cwd isn't where you call  // `bazel run` but from inside a temp folder.)  GraphDef graph_def;  status = ReadBinaryProto(Env::Default(), "models/graph.pb", &graph_def);  if (!status.ok()) {    std::cout << status.ToString() << "\n";    return 1;  }  // Add the graph to the session  status = session->Create(graph_def);  if (!status.ok()) {    std::cout << status.ToString() << "\n";    return 1;  }  // Setup inputs and outputs:  // Our graph doesn't require any inputs, since it specifies default values,  // but we'll change an input to demonstrate.  Tensor a(DT_FLOAT, TensorShape());  a.scalar<float>()() = 3.0;  Tensor b(DT_FLOAT, TensorShape());  b.scalar<float>()() = 2.0;  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {    { "a", a },    { "b", b },  };  // The session will initialize the outputs  std::vector<tensorflow::Tensor> outputs;  // Run the session, evaluating our "c" operation from the graph  status = session->Run(inputs, {"c"}, {}, &outputs);  if (!status.ok()) {    std::cout << status.ToString() << "\n";    return 1;  }  // Grab the first output (we only evaluated one graph node: "c")  // and convert the node to a scalar representation.  auto output_c = outputs[0].scalar<float>();  // (There are similar methods for vectors and matrices here:  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)  // Print the results  std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30>  std::cout << output_c() << "\n"; // 30  // Free any resources used by the session  session->Close();  return 0;}

然后我们需要为我们的项目创建一个BUILD文件,这会告诉bazel要编译什么东西。在BUILD文件里我们要定义一个cc_binary,表示输出一个二进制文件。

cc_binary(    name = "loader",    srcs = ["loader.cc"],    deps = [        "//tensorflow/core:tensorflow",    ])

那么最终文件结构如下:

  • tensorflow/tensorflow/loader/
  • tensorflow/tensorflow/loader/loader.cc
  • tensorflow/tensorflow/loader/BUILD

编译和运行

  • 在tensorflow repo的根目录下,运行./configure
  • 在tensorflow/tensorflow/loader目录下,运行bazel build :loader
    • 如果编译的时候遇到一大串undefined reference to ...的话建议用bazel build —config=monolithic :loader编译,参考https://github.com/tensorflow/tensorflow/issues/13267
  • 在tensorflow repo的根目录下,cd到 bazel-bin/tensorflow/loader目录下
  • 将graph protobuf 拷贝到models/graph.pb
  • 运行./loader,得到输出!

Reference

  1. Loading a TensorFlow graph with the C++ API
  2. tensorflow#issue:Packaged TensorFlow C++ library for bazel-independent use
原创粉丝点击