tensorflow关于求最大值所在位置函数解读

来源:互联网 发布:数据库的四个基本概念 编辑:程序博客网 时间:2024/05/21 11:09
import tensorflow as tfimport numpy as npimport randomaa=[]for i in range(42):    aa.append(random.randint(11,100))a=np.array(aa)a=a.reshape(7,6)A=tf.constant(a)B=tf.reshape(A,[-1,2,3])#变成7个2*3c=tf.argmax(B,2)#按第二个维度求最大值所在的位置,7个样本,每个样本对应2个位置(3个中最大的位置)with tf.Session() as sess:    sess.run(tf.initialize_all_variables())    print sess.run(A)    print sess.run(B)    print sess.run(c)

对应的输出为

[[88 84 82 22 95 45] [59 76 73 92 59 63] [70 60 37 62 69 51] [29 50 82 21 43 49] [51 61 78 69 49 82] [95 54 14 32 80 28] [48 73 45 75 63 64]][[[88 84 82]  [22 95 45]] [[59 76 73]  [92 59 63]] [[70 60 37]  [62 69 51]] [[29 50 82]  [21 43 49]] [[51 61 78]  [69 49 82]] [[95 54 14]  [32 80 28]] [[48 73 45]  [75 63 64]]][[0 1] [1 0] [0 1] [2 2] [2 2] [0 1] [1 0]]
原创粉丝点击