tensorflow添加自定义的auc计算operator
来源:互联网 发布:淘宝囤货小当家 编辑:程序博客网 时间:2024/05/17 12:05
tensorflow添加自定义的auc计算operator
tensorflow可以很方便的添加用户自定义的operator(如果不添加也可以采用sklearn的auc计算函数或者自己写一个但是会在python执行,这里希望在graph中也就是c++端执行这个计算)
这里根据工作需要添加一个计算auc的operator,只给出最简单实现,后续高级功能还是参考官方wiki
https://www.tensorflow.org/versions/r0.7/how_tos/adding_an_op/index.html
注意tensorflow现在和最初的官方wiki有变化,原wiki貌似是需要重新bazel编译整个tensorflow,然后使用比如tf.user_op.auc这样。
目前wiki给出的方式>=0.6.0版本,采用plug-in的方式,更加灵活可以直接用g++编译一个so载入,解耦合,省去了编译tensorflow过程,即插即用。
首先auc的operator计算的文件
tensorflow/core/user_ops/auc.cc
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// An auc Op.
#include"tensorflow/core/framework/op.h"
#include"tensorflow/core/framework/op_kernel.h"
usingnamespace tensorflow;
usingstd::vector;
//@TODO add weight as optional input
REGISTER_OP("Auc")
.Input("predicts: T1")
.Input("labels: T2")
.Output("z: float")
.Attr("T1: {float, double}")
.Attr("T2: {float, double}")
//.Attr("T1: {float, double}")
//.Attr("T2: {int32, int64}")
.SetIsCommutative()
.Doc(R"doc(
Given preidicts and labels output it's auc
)doc");
classAucOp :public OpKernel {
public:
explicitAucOp(OpKernelConstruction*context) : OpKernel(context) {}
template<typenameValueVec>
voidindex_sort(constValueVec&valueVec, vector<int>&indexVec)
{
indexVec.resize(valueVec.size());
for (size_ti = 0; i < indexVec.size();i++)
{
indexVec[i] =i;
}
std::sort(indexVec.begin(),indexVec.end(),
[&valueVec](constintl,constintr) {returnvalueVec(l) >valueVec(r); });
}
voidCompute(OpKernelContext*context) override {
// Grab the input tensor
const Tensor&predicts_tensor =context->input(0);
const Tensor&labels_tensor =context->input(1);
autopredicts =predicts_tensor.flat<float>();//输入能接受float double那么这里如何都处理?
autolabels =labels_tensor.flat<float>();
vector<int>indexes;
index_sort(predicts,indexes);
typedeffloatFloat;
FloatoldFalsePos = 0;
FloatoldTruePos = 0;
FloatfalsePos = 0;
FloattruePos = 0;
FloatoldOut =std::numeric_limits<Float>::infinity();
Floatresult = 0;
for (size_ti = 0; i < indexes.size();i++)
{
intindex =indexes[i];
Floatlabel =labels(index);
Floatprediction =predicts(index);
Floatweight = 1.0;
//Pval3(label, output, weight);
if (prediction !=oldOut) //存在相同值得情况是特殊处理的
{
result += 0.5 * (oldTruePos + truePos) * (falsePos -oldFalsePos);
oldOut =prediction;
oldFalsePos =falsePos;
oldTruePos =truePos;
}
if (label > 0)
truePos +=weight;
else
falsePos +=weight;
}
result += 0.5 * (oldTruePos + truePos) * (falsePos -oldFalsePos);
FloatAUC =result / (truePos *falsePos);
// Create an output tensor
Tensor*output_tensor =NULL;
TensorShapeoutput_shape;
OP_REQUIRES_OK(context,context->allocate_output(0,output_shape, &output_tensor));
output_tensor->scalar<float>()() =AUC;
}
};
REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU),AucOp);
编译:
$cat gen-so.sh
TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
i=$1
o=${i/.cc/.so}
g++ -std=c++11 -shared $i -o $o -I $TF_INC -l tensorflow_framework -L $TF_LIB -fPIC -Wl,-rpath $TF_LIB
$sh gen-so.sh auc.cc
会生成auc.so
使用的时候
auc_module = tf.load_op_library('auc.so')
#auc = tf.user_ops.auc #0.6.0之前的tensorflow自定义op方式
auc = auc_module.auc
evaluate_op = auc(py_x, Y) #py_x is predicts, Y is labels
- tensorflow添加自定义的auc计算operator
- AUC的计算
- ROC与AUC的计算
- TensorFlow计算AUC错误:Attempting to use uninitialized value auc/false_positives
- AUC计算
- auc计算 代码
- ROC,AUC,Precision,Recall,F1的介绍与计算
- ROC曲线和EER/AUC的计算方式
- 拨开自定义operator new与operator delete的迷雾
- 拨开自定义operator new与operator delete的迷雾
- 拨开自定义operator new与operator delete的迷雾
- 开自定义operator new与operator delete的迷雾
- 拨开自定义operator new与operator delete的迷雾
- 拨开自定义operator new与operator delete的迷雾
- 自定义operator new与operator delete的使用(1)
- 自定义operator new与operator delete的使用(2)
- 拨开自定义operator new与operator delete的迷雾
- 拨开自定义operator new与operator delete的迷雾
- IDEA创建servlet+jstl+jdbc
- Oracle PO全过程/标准流程及分析
- springboot TestNg (一) 环境准备与Helloword
- 关于oracle 检查点队列 脏块的理解
- 区块链BTC98比特币Bitcoin源代码安装编译
- tensorflow添加自定义的auc计算operator
- iOS StoryBoard的转场Segue解读
- JMS学习笔记(二)——使用JMS发送和接受text、Map、Object类型的消息
- 《JavaFx 开发教程》之 环境搭建
- Temporal Action Detection (时序动作检测)方向2017年会议论文整理
- nginx二 之负载均衡搭建
- $.ajax()方法详解
- 如何在mac上安装android sdk
- Java简易五子棋