TensorFlow:top_k()和区别in_top_k()

来源:互联网 发布:亚丝娜的剑数据 编辑:程序博客网 时间:2024/06/06 07:38

      从字面上可以大致了解这两个函数的自用,但具体的作用,还需要查看源码,及编程实现,这样掌握和了解的比较透彻。

in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
     首先说作用:返回一个布尔向量,说明目标值是否存在于预测值之中。

  参数:predicitions:输入的输入tensor,数据类型必须是以下之一:float32、float64、int32、int64、uint8、int16、int8。

              targets:tensor,数据类型是 int32 。每行目标值所在的位置,如果predicitions某行的最大值位置为n, n==targets,则该行的返回值为True

              k: 最大值的个数,k值关系返回矩阵的结果。如果k=1,最大值的位置是否在targets处。

例如:

    input = tf.constant(np.random.rand(3,4), tf.float32)    k = 1     output = tf.nn.in_top_k(input, [3,3,3], k)#每一行的最大值都在第3列(0为第一列)    with tf.Session() as sess:        print(sess.run(input))        print(sess.run(output))
输出:

[[ 0.46714601  0.92652822  0.16808732  0.44906664]#最大值在第1列,返回为false [ 0.03874864  0.55331773  0.32944077  0.84536946]#最大值在第3列,返回为false [ 0.80283058  0.63945484  0.07212774  0.27699497]]最大值在第1列,返回为false[False  True False]
    如果k=3呢?

[[ 0.10950958  0.09272877  0.65265322  0.49682239]#最大值在第二列,第二个次大值在第3列,返回true
 [ 0.70769322  0.00581258  0.40589932  0.7010119 ]
 [ 0.18922156  0.57137531  0.14654963  0.26083347]]
[ True  True  True]

top_k(input, k=1, sorted=True, name=None):
"""Finds values and indices of the `k` largest entries for the last dimension.
   作用:返回 input 中每行最大k 个数的值,并且返回它们所在位置的索引。

   参数:input:输入的输入tensor,数据类型必须是以下之一:float32、float64、int32、int64、uint8、int16、int8。

例如:

    input = tf.constant(np.random.rand(3,4), tf.float32)    k = 1  #targets对应的索引是否在最大的前k(2)个数据中    output = tf.nn.top_k(input, k)    with tf.Session() as sess:        print(sess.run(input))        print(sess.run(output))
   输出:

TopKV2(values=array([[ 0.87421292],       [ 0.96415848],       [ 0.54568386]], dtype=float32), indices=array([[3],       [1],       [2]]))#每一行的最大值,与最大值所在的位置。
如果k=2

[[ 0.98679858  0.09883292  0.19342254  0.20967487] [ 0.12573749  0.60547918  0.54529655  0.08391853] [ 0.80146015  0.38433447  0.68723434  0.04177354]]TopKV2(values=array([[ 0.98679858,  0.20967487],       [ 0.60547918,  0.54529655],       [ 0.80146015,  0.68723434]], dtype=float32), indices=array([[0, 3],       [1, 2],       [0, 2]]))#每行中,前两个最大值,及它们所在的位置。




原创粉丝点击