tf.argmax

来源:互联网 发布:java多线程是什么意思 编辑:程序博客网 时间:2024/05/22 03:19

tf.argmax(input, axis=None, name=None, dimension=None)
此函数是对矩阵按行或列计算最大值

参数
input:输入Tensor
axis:0表示按列,1表示按行
name:名称
dimension:和axis功能一样,默认axis取值优先。新加的字段
返回:Tensor 一般是行或列的最大值下标

import tensorflow as tf  a=tf.Variable(tf.random_uniform([3,4],minval=-1,maxval=1))  b=tf.argmax(input=a,axis=0)  c=tf.argmax(input=a,dimension=1)   #此处用dimesion或用axis是一样的  sess = tf.Session() sess.run(tf.initialize_all_variables())  print(sess.run(a))  #[[ 0.04261756 -0.34297419 -0.87816691 -0.15430689]  # [ 0.18663144  0.86972666 -0.06103253  0.38307118]  # [ 0.84588599 -0.45432305 -0.39736366  0.38526249]]  print(sess.run(b))  #[2 1 1 2]  print(sess.run(c))  #[0 1 0]  
0 0
原创粉丝点击