Tensorflow常用函数笔记
来源:互联网 发布:床垫哪个品牌好 知乎 编辑:程序博客网 时间:2024/06/06 00:09
Tensorflow常用函数笔记
tf.concat
把一组向量从某一维上拼接起来,很向numpy中的Concatenate,官网例子:
t1 = [[1, 2, 3], [4, 5, 6]]t2 = [[7, 8, 9], [10, 11, 12]]tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]# tensor t3 with shape [2, 3]# tensor t4 with shape [2, 3]tf.shape(tf.concat([t3, t4], 0)) ==> [4, 3]tf.shape(tf.concat([t3, t4], 1)) ==> [2, 6]
其实,如果是list类型的话也是可以的,只要是形似Tensor,最后tf.concat返回的还是Tensor类型
tf.gather
类似于数组的索引,可以把向量中某些索引值提取出来,得到新的向量,适用于要提取的索引为不连续的情况。这个函数似乎只适合在一维的情况下使用。
import tensorflow as tf a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])index_a = tf.Variable([0,2])b = tf.Variable([1,2,3,4,5,6,7,8,9,10])index_b = tf.Variable([2,4,6,8])with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.gather(a, index_a))) print(sess.run(tf.gather(b, index_b)))# [[ 1 2 3 4 5]# [11 12 13 14 15]]# [3 5 7 9]
tf.gather_nd
同上,但允许在多维上进行索引,例子只展示了一种很简单的用法,更复杂的用法可见官网。
import tensorflow as tf a = tf.Variable([[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]])index_a = tf.Variable([[0,2], [0,4], [2,2]])with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.gather_nd(a, index_a)))# [ 3 5 13]
tf.greater
判断函数。首先张量x和张量y的尺寸要相同,输出的tf.greater(x, y)也是一个和x,y尺寸相同的张量。如果x的某个元素比y中对应位置的元素大,则tf.greater(x, y)对应位置返回True,否则返回False。与此类似的函数还有tf.greater_equal。
import tensorflow as tf x = tf.Variable([[1,2,3], [6,7,8], [11,12,13]])y = tf.Variable([[0,1,2], [5,6,7], [10,11,12]])x1 = tf.Variable([[1,2,3], [6,7,8], [11,12,13]])y1 = tf.Variable([[10,1,2], [15,6,7], [10,21,12]])with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.greater(x, y))) print(sess.run(tf.greater(x1, y1)))# [[ True True True]# [ True True True]# [ True True True]]# [[False True True]# [False True True]# [ True False True]]
tf.cast
转换数据类型。
a = tf.constant([0, 2, 0, 4, 2, 2], dtype='int32')print(a)# <tf.Tensor 'Const_1:0' shape=(6,) dtype=int32>b = tf.cast(a, 'float32')print(b)# <tf.Tensor 'Cast:0' shape=(6,) dtype=float32>
tf.expand_dims & tf.squeeze
增加 / 压缩张量的维度。
a = tf.constant([0, 2, 0, 4, 2, 2], dtype='int32')print(a)# <tf.Tensor 'Const_1:0' shape=(6,) dtype=int32>b = tf.expand_dims(a, 0)print(b)# <tf.Tensor 'ExpandDims:0' shape=(1, 6) dtype=int32>print(tf.squeeze(b, 0))# <tf.Tensor 'Squeeze:0' shape=(6,) dtype=int32>
阅读全文
0 0
- tensorflow常用函数笔记
- Tensorflow常用函数笔记
- tensorflow笔记:常用函数
- 【Tensorflow】tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记:常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记 :常用函数说明
- tensorflow笔记(二) :常用函数说明
- (四) tensorflow笔记:常用函数说明
- React-native View组件transform样式
- 关于PaxCompiler字符串注意的问题
- cmd命令行下运行不了.py文件,但是编译器运行没问题,报错MoudleFoundError
- OC语言学习23-Block在类中的应用
- **
- Tensorflow常用函数笔记
- 【CQOI2016】手机号码
- A
- 记FreeCodeCamp中遇到的题目--js
- Spark-SQL之DataFrame操作大全
- 12. Servlet 页面点击计数器
- SIFT算法详解
- 有关Facebook Graph Api 中的一些笔记
- Unity_线渲染器和拖尾_024