如何查找Tensorflow operation的实现源码

来源:互联网 发布:房地产公司 水暖 知乎 编辑:程序博客网 时间:2024/06/05 07:24

如何查找Tensorflow operation的实现源码

笔者由于工作原因经常需要查阅tensorflow各个operation的实现,然而有些op实在没法猜到它到底定义在那个文件里,全文搜索op的名称又经常搜出来太多的文件,无法快速筛选。

近日笔者研究了一下tensorflow增加新的op的方式,发现了一个查找op实现的好方法。

一般来说,在tensorflow中增加一个新的op需要两步。(以下内容均参考自tensorflow官方文档)

  • 定义这个op的接口,并注册到tensorflow中。
    在接口中定义中,需要指定这个op的input, output以及相关的一些attribute。定义op接口需要调用宏REGISTER_OP,例如:
#include "tensorflow/core/framework/op.h"#include "tensorflow/core/framework/shape_inference.h"using namespace tensorflow;REGISTER_OP("ZeroOut")    .Input("to_zero: int32")    .Output("zeroed: int32")    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {      c->set_output(0, c->input(0));      return Status::OK();    });
  • 注册这个op的实现。需要调用宏REGISTER_KERNEL_BUILDER。例如:
#include "tensorflow/core/framework/op_kernel.h"using namespace tensorflow;class ZeroOutOp : public OpKernel { public:  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}  void Compute(OpKernelContext* context) override {    // Grab the input tensor    const Tensor& input_tensor = context->input(0);    auto input = input_tensor.flat<int32>();    // Create an output tensor    Tensor* output_tensor = NULL;    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),                                                     &output_tensor));    auto output_flat = output_tensor->flat<int32>();    // Set all but the first element of the output tensor to 0.    const int N = input.size();    for (int i = 1; i < N; i++) {      output_flat(i) = 0;    }    // Preserve the first input value if possible.    if (N > 0) output_flat(0) = input(0);  }};REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

因此,对于一般的op,在TensorFlow源代码中全文搜索REGISTER_OP("OpName")REGISTER_KERNEL_BUILDER(Name("OpName")基本上就可以精确定位到这个op的接口定义和具体实现了。

原创粉丝点击