TensorFlow插曲--tf.argmax函数

来源:互联网 发布:血源诅咒捏脸美女数据 编辑:程序博客网 时间:2024/05/20 10:56

转自:http://blog.csdn.net/zj360202/article/details/70259999

tf.argmax(input, axis=None, name=None, dimension=None)

此函数是对矩阵按行或列计算最大值

参数
  • input:输入Tensor
  • axis:0表示按列,1表示按行
  • name:名称
  • dimension:和axis功能一样,默认axis取值优先。新加的字段
返回:Tensor  一般是行或列的最大值下标向量

例:
[java]view plain copy
 print?
  1. import tensorflow as tf  
  2.   
  3.   
  4. a=tf.get_variable(name='a',  
  5.                   shape=[3,4],  
  6.                   dtype=tf.float32,  
  7.                   initializer=tf.random_uniform_initializer(minval=-1,maxval=1))  
  8. b=tf.argmax(input=a,axis=0)  
  9. c=tf.argmax(input=a,dimension=1)   #此处用dimesion或用axis是一样的  
  10. sess = tf.InteractiveSession()  
  11. sess.run(tf.initialize_all_variables())  
  12. print(sess.run(a))  
  13. #[[ 0.04261756 -0.34297419 -0.87816691 -0.15430689]  
  14. # [ 0.18663144  0.86972666 -0.06103253  0.38307118]  
  15. # [ 0.84588599 -0.45432305 -0.39736366  0.38526249]]  
  16. print(sess.run(b))  
  17. #[2 1 1 2]  
  18. print(sess.run(c))  
  19. #[0 1 0]  
原创粉丝点击