tensorflow_api_2:tf.argmax( )

来源:互联网 发布:炫浪网络在线阅读手机 编辑:程序博客网 时间:2024/05/16 15:58

tf.argmax( )

  • 函数作用:
    计算矩阵每行或每列最大值的索引

  • 参数:
    tf.argmax(input, axis = None, name = None, dimension = None)
    input:输入tensor
    axis:0表示按列,1表示按行
    name:自定义输出tensor的名称
    dimension:和axis功能一样,默认axis取值优先。

  • 返回:
    行或列最大值的索引,组成的tensor

  • 例子:

test = tf.constant([[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]])test_0 = tf.argmax(test, 0)   # 按列比较,返回每列最大元素的索引test_1 = tf.argmax(test, 1)   # 按行比较,返回每行最大元素的索引with tf.Session() as sess:    print(sess.run(test))        print()    print(sess.run(test_0))   # 输出[3 3 3]     print()    print(sess.run(test_1))   # 输出[2 2 2 2]