tensorflow:常用API-'a'

来源:互联网 发布:网络信息保密协议书 编辑:程序博客网 时间:2024/06/05 09:35

1.加法操作
tf.accumulate_n、tf.add_n、tf.add

import tensorflow as tfa = tf.constant([[1, 2], [3, 4]])b = tf.constant([[5, 0], [0, 6]])c = tf.constant([2, 3])sess = tf.InteractiveSession()print(tf.accumulate_n([a, b]).eval())print(tf.add_n([a, b]).eval())print(tf.add(a, b).eval())# 输出的结果都是# [[ 6  2]# [ 3 10]]# 但是tf.add支持broadcastingprint(tf.add(a, c).eval())# [[6 2]# [8 4]]# print(tf.add_n([a,c]).eval()) #不支持广播-error

小结:多个tensor对应相加,推荐tf_add_n,若需要支持广播(2个shape不一样的tensor进行操作),请使用tf.add

2.argmax和argmin

a1 = tf.constant([[4, 2, 3], [1, 6, 5]])print(tf.argmax(a1, axis=0).eval())# 默认,按列取最大值的下标[0 1 1]print(tf.argmax(a1, axis=1).eval())# 按行取最大值的下标[0 1]

tf还提供了arg_max函数,其功能和argmax一样,arg_max是一个待抛弃的函数,推荐使用argmax,argmin和argmax类似

3.assign
assign对tensor的引用进行重新赋值

a2 = tf.Variable(3, dtype=tf.float32)sess.run(tf.global_variables_initializer())print(a2.eval())  # 3tf.assign(a2, 5).eval()  # 必须得eval执行下print(a2.eval())  # 5tf.assign_add(a2,2).eval() #加上一个值再从新赋值print(a2.eval())  # 7

4.as_string
类似tostring函数,不过作用在tensor上

a3 = tf.constant([[1.13, 2.02], [3.5, 4.433]])print(tf.as_string(a3,precision=2).eval()) #保留2位小数# [[b'1.13' b'2.02']# [b'3.50' b'4.43']]
原创粉丝点击