tensorflow实现P@1和MRR
来源:互联网 发布:哪个软件可以化妆 编辑:程序博客网 时间:2024/06/06 01:36
最近复现别人的论文,才发现tf竟然就只有常用的一些loss, 并没有其他指标,例如P@1和MRR。在手工计算这些指标的过程中,发现了几个神奇的函数。
任务介绍
输入是各个候选的分值tensor_score,假设有每个样本对应5个候选,那么输入大小就是[batch_size, 5]。target指定了候选中实际匹配一个为1,其余为0,大小为[batch_size,1]
现在我们的目标是:得到5个候选从高到低的排序,并且把最高的取出来与target对比,若其target为1则正确,否则错误,计算P@1。然后取targer中值为1的候选,得到它的排名,计算MRR。
具体的计算公式如下:
r()表示正确答案A的排序,
为了方便计算,我把正确答案(即对应标签为1的候选)放在了第一个,剩余4个都是错误答案。
这个任务本身并不难,难的是在神经网络中我想通过矩阵运算的方式(可以并行)而不是用for循环进行元素运算。
原料
找了很久,终于找到了以下四个平时没听过的函数:
- tf.nn.top_k(input,k,sorted=True)
返回[value,index]tuple,用于对tensor_score排序。如果input是个向量的话,则返回其中最大的k个值和相应的序号。如果input是个矩阵,则返回每行最大的k个值和相应的序号矩阵。
sample=[[1,2,3,4],[5,6,7,8],[9,1,2,12]]top_v,top_idx=tf.nn.top_k(sample,k=4)idx=s.run(top_idx)print idx
output:
[[3 2 1 0] [3 2 1 0] [3 0 2 1]]
tf.equal(input,target)
判读输入的矩阵input的值是否等于target,返回一个元素为bool型的矩阵。当target 是一个数值时,可以broadcast.tf.where(input)
查找input矩阵中值为True的元素,并返回一个包含各True元素的序号的矩阵。这个矩阵的维度很奇怪,建议大家看完官方文档后自己写个例子跑跑。通常来说,如果输入一维向量,那么返回的就是一维向量,里面是True元素的序号。如果输入二维矩阵,返回的就是2维矩阵,第一列各True所在的行序号,第二列表示各True元素的列序号。例子:继续上面的idx
where=tf.unpack(tf.where(tf.equal(idx,0))print s.run(where)
output:
[[0 3] [1 3] [2 1]]
- tf.reduce_mean(logits,reduce_indices=None)
这个函数用于在某个维度reduce_indices上计算矩阵的平均值,当不指定时求的是整个矩阵的平均值。举个例子:
p=tf.reduce_mean([0.0,1.0,1.0])q=s.run(p)print q
output:
0.666667
要注意的是,输出的值的类型和输入一致,如果输入是整数,输出也会强制转成整数,那是错误的结果。
P@1和MRR函数
代码如下:
def predict(score_tensor,batch_size, unit_size): ''' :param score_tensor: shape [batch_size, unit_size] :return: ranking of scores ''' ranks = tf.nn.top_k(score_tensor, k=unit_size) pred_idx = tf.argmax(score_tensor,1) p = tf.reduce_mean(tf.cast(tf.equal(pred_idx, 0), "float"), name='p_at_1') true_rank = tf.slice(tf.where(tf.equal(ranks, 0)),[0,1],[batch_size,1]) mrr = tf.reduce_mean(1.0 / true_rank+1) return pred_idx, p, mrr
- tensorflow实现P@1和MRR
- mysql ICP和MRR性能优化测试
- 内存管理机制MRR/MRC和ARC
- p、*p和&p
- alexnet tensorflow 实现和训练
- p、*p和&p
- *p,p和&p区别:
- 内存管理MRR/ARC,property和内存泄露
- mysql5.6中mrr和icp优化简述
- IR的评价指标-MAP,NDCG和MRR
- IR的评价指标-MAP,NDCG和MRR
- mysql5.6中mrr和icp优化简述
- p++和++p
- 关于*p++和++p*
- 指针 *p++和*++p
- *p++和*++p
- p++ 和++p
- *&p和**p
- 设计模式_享元模式
- 惠盈宝
- (转)MFC程序(在静态库中使用MFC)问题
- 欢迎使用CSDN-markdown编辑器
- oracle数据库循环语句
- tensorflow实现P@1和MRR
- 关于hexo搭建博客后的一些基础配置
- Jsp状态管理
- 第七届Java软件开发C组
- java.io.FileNotFoundException: class path resource [applicationContext.xml] cannot be opened because
- [LeetCode]13. Roman to Integer
- Elasticsearch垃圾回收日志
- 第一篇文章·前言
- Makefile编译时怎么打印出变量值