TensorFlow Python API解析:图的核心数据结构

来源:互联网 发布:js 微信打开链接 编辑:程序博客网 时间:2024/04/29 17:34

转自http://blog.csdn.net/deepknower/article/details/53824653

本文默认的一些名称叫法
1. “Tensor实例”、”Tensor对象”、”tensor”都是在说一个Tensor实例。同样,”Operation实例”、”Operation对象”、”operation”都是在说一个Operation实例。
2. 单独的”Tensor”和”Operation”都表示一个Python class
3. “operation”有时简写为”op”或”Op”,”tensor”有时简写为”t”

代码位置:tensorflow/tensorflow/python/framework/ops.py

一、Graph类

1. 要点

  1. TensorFlow中的计算,表示为一个数据流图,简称“图”
  2. 一个Graph实例就是一个图,由一组Operation对象和Tensor对象构成:每个Operation对象(简记为op)表示最小的计算单元,每个Tensor对象表示在operations间传递的基本数据单元
  3. 如果你没有注册自己的图,系统会提供一个默认图。你可通过调用tf.get_default_graph()显式地访问这个图,也可以不理会这个图,因为调用任一个operation函数时,如调用constant op,c=tf.constant(4.0),一个表示operation的节点会自动添加到这个图上,此时c.graph就指这个默认图。
  4. 如果我们创建了一个Graph实例,并想用它取代上面的默认图,把它指定为一个新的默认图,至少是临时换一下,可以调用该Graph实例的as_default()方法,并得到一个Python中的上下文管理器(context manager),来管理临时默认图的生命周期,即with ...下的代码区域。
g = tf.Graph()with g.as_default():  # 此时定义的operation和tensor都自动添加到图g上  c = tf.constant(30.0)
  • 1
  • 2
  • 3
  • 4

小提示:组装图阶段,Graph类不是线程安全的,添加operations最好在单线程内完成

2. Graph的属性

内部属性

  • 与operation相关:
    • _nodes_by_id:dict( op的id => op ),按id记录所有添加到图上的op
    • _nodes_by_name:dict( op的name => op ),按名字记录所有添加到图上的op
    • _next_id_counter:int,自增器,创建下一个op时用的id
    • _version:int,记录所有op中最大的id
    • _default_original_op:有些op需要附带一个original_op,如replica op需要指出它要对哪个op进行复制
    • _attr_scope_map:dict( name scope => attr ),用于添加一组额外的属性到指定scope中的所有op
    • _op_to_kernel_label_map:dict( op type => kernel label ),kernel可能是指operation中更底层的实现
    • _gradient_override_map:dict( op type => 另一个op type ),把一个含自定义gradient函数的注册op,用在一个已存在的op上
  • 与命名域name scope相关:
    • _name_stack:字符串,嵌套的各个scopes的名字拼成的栈,用带间隔符”/”的字符串表示
    • _names_in_use: dict( name scope => 使用次数 )
  • 与device相关:
    • _device_function_stack: list,用来选择device的函数栈,每个元素是一个device_function(op),用来获取op所在device
  • 与控制流相关:
    • _control_flow_context:一个context对象,表示当前控制流的上下文,如CondContext对象,WhileContext对象,定义在ops/control_flow_ops.py。实际上,控制流也是一个op,用来控制其他op的执行,添加一些条件依赖的关系到图中,使执行某个operation前先查看依赖
    • _control_dependencies_stack:list,一个控制器栈,每个控制器是一个上下文,存有控制依赖信息,表明当执行完依赖中的operations和tensors后,才能执行此上下文中的operations
  • 与feed和fetch相关:
    • _unfeedable_tensors:set,定义不能feed的tensors
    • _unfetchable_ops:set,定义不能fetch的ops
    • _handle_feeders:dict( tensor handle placeholder => tensor dtype )
    • _handle_readers:dict( tensor handle => 它的read op )
    • _handle_movers:dict( tensor handle => 它的move op )
    • _handle_deleters:dict( tensor handle => 它delete op )
  • 图需要:
    • _seed:当前图内使用的随机种子
    • _collections:dict( collection name => collection ),相当于图中的一块缓存,每个collection可看成一个list,可以存任何对象
    • _functions:定义图内使用中的一些函数
    • _container:资源容器resource container,用来存储跟踪stateful operations,如:variables,queues
    • _registered_ops:注册的所有操作
  • 程序运行需要:
    • _finalized:布尔值,真表示Graph属性都已确定,不再做修改
    • _lock:保证读取Graph某些属性(如:_version)时尽可能线程安全
  • TensorFlow框架需要:
    • _graph_def_version:图定义的版本
  • 其他:
    • _building_function:该图是否表示一个函数
    • _colocation_stack:保存共位设置(其他op都与指定op共位)的栈

对外属性

  • tf.Graph.version,也就是self._version,记录最新的节点version,即图中最大op id,但是与GraphDef的version无关
  • tf.Graph.graph_def_versions,也就是self._graph_def_versions,GraphDef版本,定义在tensorflow/tensorflow/core/framework/graph.proto
  • tf.Graph.seed,也就是self._seed,此图内使用的随机种子
  • tf.Graph.building_function,也就是self._building_function
  • tf.Graph.finalized,也就是self._finalized,表明组装图阶段是否完成

3. Graph的主要方法

构造方法

  • tf.Graph.__init__():创建一个空图

获取图元素(tensors, operations)的方法

  • tf.Graph.as_graph_def(from_version=None, add_shapes=False):返回该graph对应的GraphDef表示,使用了protocol buffer,见下面的message GraphDef。该方法是线程安全的。
    • 传参:(1) from_version表明包括的节点version(即op id)的范围,from_version之前的节点都不要;(2)add_shapes如果为真,则每个节点都要添加输出tensors的形状信息到_output_shapes
messsage GraphDef {  // 图中的所有节点,参见下面message NodeDef  repeated NodeDef node = 1;  // 图的版本,不同于TensorFlow版本  VersionDef versions = 4;  // 丢弃  int32 version = 3 [deprecated = true];  // Experimental. 提供用户自定义的函数  FunctionDefLibrary library = 2;}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • tf.Graph.as_graph_element(obj, allow_tensor=True, allow_operation=True):该获取信息的方法实际上完成了一个验证加转换的工作,给定一个obj,看它能否对应到图中的元素,可以是一个operation,也可以是一个tensor,如果对应,则以operation或tensor的身份返回它自己。该方法可以被多个线程同时调用。
    • 传参:(1) obj:可以是一个Tensor对象,或一个Operation对象,或tensor名,或operation名,或其他对象;(2)allow_tensor:真表示obj可以是tensor;(3) allow_operation:真表示obj可以是operation
  • get系列方法,可被多个线程同时调用:
    • tf.Graph.get_operation_by_name(name):根据名字获取某个operation
    • tf.Graph.get_tensor_by_name(name):根据名字获取某个tensor
    • tf.Graph.get_operations():获取所有operations
  • 判断是否可feed或可fetch
    • tf.Graph.is_feedable(tensor)
    • tf.Graph.is_fetchable(tensor_or_op)
  • 设置不可feed或不可fetch
    • tf.Graph.prevent_feeding(tensor)
    • tf.Graph.prevent_fetching(op)

添加节点组装图的方法

  • tf.Graph.unique_name(name, mark_as_used=True):为operation name构造一个唯一名,唯一名可能包含分隔符"/"。传参mark_as_used表示构造的唯一名,只是用来看看,还是要被创建出来使用的,将其传给方法create_op()来创建一个operation。
  • tf.Graph.create_op(op_type, inputs, dtypes, ...):这是一个低级别接口,开发者一般用不到,因为只用具体op的构造函数,如tf.constant(),即可实现向图添加op节点。此方法返回一个Operation实例,却有很多传入参数,包括:
    • 必填的有:(1) op_type:创建的op类型,也就是操作方法名,如”MatMul”,对应OpDef.name字段;(2) inputs:op的输入,是一个由Tensor对象组成的列表;(3) dtypes:op输出的tensors的数据类型,是一个由DType对象组成的列表
    • 可选的有:(1) input_types:op输入的tensors的类型,是一个由DType对象组成的列表,默认使用inputs中的tensors自带的Dtype;(2)name:op做节点的名字,默认基于op_type构造出;(3) attrs:dict( 属性名 => operation的属性),在NodeDef proto中有定义;(4)op_def:OpDef proto,是一个描述operation操作方法的protocol buffer;(5) compute_shapes:布尔值,是否计算输出的tensors的形状;(6)compute_device:布尔值,是否执行device_function来获取operation的device

结束组装图的定稿方法

  • tf.Graph.finalize():结束组装图,以后图只能读不能写,不能再添加新operation节点。此方法用在图要在多个线程间共享的场景下,如用于QueueRunner

切换默认图的方法

  • tf.Graph.as_default():让当前图取代默认图。一般来说,不常使用,因为当前线程会自动提供一个全局默认图,也就是说,全局默认图是当前线程的一个属性,新建一个线程后全局默认图就变了。除非你在同一个进程内创建了多个图,才会有用这个方法的需求。
    • 返回:一个context manager,实际上当前设为默认的graph就是一个上下文,在它的代码块内执行的op都会添加到该图上
# 两种等价方法# 方法一:g = tf.Graph()with g.as_default():  ...# 方法二:with tf.Graph().as_default() as g: # 当前上下文就是刚创建的g  ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

返回新上下文的方法

  • tf.Graph.control_dependencies(control_inputs):返回一个控制依赖的上下文,使得上下文内的新加入op都有此依赖。控制依赖的意思是,若想执行下步的op,必须先完成依赖中的input。此方法的上下文就是一个控制器controller,controller内保存了新建的控制依赖,同时controller加入一个控制器栈controller stack,以支持嵌套的控制依赖上下文。
    • 传参:control_inputs是一个operation或tensor的列表
with g.control_dependencies([a, b]):  # 这里新建的op在a,b后执行  with g.control_dependencies([c, d]):    # 这里新建的op在a,b,c,d后执行    with g.control_dependencies(None):      # 因为依赖链断掉,这里新建的op不需等待a,b,c,d      with.control_dependencies([e, f])        # 这里新建的op在e,f后执行
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • tf.Graph.device(device_name_or_function):返回一个默认device的上下文,使得上下文内的新加入op都被分派到该device上。此方法的上下文就是一个device function,它被压入一个device function stack,以支持嵌套。
    • 传参:device_name_or_function可以是一个表示device名的字符串,或一个返回device名的函数,或None
    • 特例:无论位于哪个device上下文中,variable assignment op v.assign()将随它的Variable对象v放在一起
def matmul_on_gpu(node):  if node.type == "MatMul": # node.type就是指op type    return "/gpu:0"  else:    return "/cpu:0"with g.device(matmul_on_gpu):  # 此处新建的所有type为"MatMul"的op将放在GPU 0上,其他则放在CPU 0上  with g.device('/gpu:0'):    # 此处新建的所有op都将放在GPU 0上
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

device的命名格式:
/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
- <name>:标识id,为一个字符串,形如[a-zA-Z][_a-zA-Z]*,比如/job:param_server为一个名为”param_server”的job
- <type>:支持的设备类型,如”cpu”或”gpu”
- <replica>,<task>,<device_num>:小的非负整数
举例:
(1) /job:w/replica:0/task:0/device:gpu:*:位于job w的replica 0, task 0上的任何gpu devices
(2) /job:*/replica:*/task:*/device:cpu:*:位于任何job/task/replica的任何cpu devices

  • tf.Graph.name_scope(name):返回一个层级命名operation的上下文。一个图维护一个命名域(或叫“命名空间”)的栈self._name_stack,此方法的传入参数name会被压入该栈,支持嵌套。
    • 参数:name可以是一个字符串,用于创建一个新的name scope;也可以是一个已有的name scope,用于重新进入这个已存在的scope;也可以是None或空字符,此时表示顶层的name scope
c = tf.constant(1.0, name="c") # c.op.name为"c"c_1 = tf.constant(2.0, name="c") # c_1.op.name为"c_1"with g.name_scope("nested") as scope:  nested_c = tf.constant(3.0, name="c") # nested_c.op.name为"nested/c"  with g.name_scope("inner"):    nested_inner_c = tf.constant(4.0, name="c") # nested_inner_c.op.name为"nested/inner/c"  with g.name_scope("inner"): # 因为此域下已有"inner",所以用"inner_1"    nested_inner_1_c = tf.constant(5.0, name="c") # nested_inner_1_c.op.name为"nested/inner_1/c"    with g.name_scope(scope): # 无论现在嵌套哪里,都转换成scope,即"nested/"      nested_c_1 = tf.constant(6.0, name="c") # nested_c_1.op.name为"nested/c_1"      with g.name_scope(""): # 变成顶级scope        c_2 = tf.constant(7.0, name="c") # c_2.op.name为"c_2"
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • tf.Graph.gradient_override_map(op_type_map):返回一个改写gradient函数的上下文,使得针对某些operation,我们可以使用自己的gradient函数。一个图维护这样一个映射关系self._gradient_override_map.
# 先注册一个gradient函数@tf.RegisterGradient("CustomSquare")def _custom_square_grad(op, grad):  # ...with tf.Graph().as_default() as g:  c = tf.constant(5.0)  s_1 = tf.square(c) # 使用tf.square默认的gradient  with g.gradient_override_map({"Sqaure": "CustomSquare"}):    s_2 = tf.square(s_2): # 使用自定义的_custom_square_grad函数来计算s_2的梯度
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • tf.Graph.colocate_with(op, ignore_existing=False):返回一个共用给定op的位置的上下文,使得上下文内的新加入op都共用这个位置。传参ignore_existing为真则表示忽略以前所有的共位设置。
a = tf.Variable([1.0])with g.colocate_with(a):  # 下面的b,c与a共位  b = tf.constant(1.0)  c = tf.add(a, b)
  • 1
  • 2
  • 3
  • 4
  • 5
  • tf.Graph.container(container_name):返回一个带资源容器的上下文,服务于带状态的operations,如:varaibles,queues,用来存储跟踪它们的状态。可使用tf.Session.reset()清除资源容器中保存的信息。
with g.container('experiment0'):  v = tf.Variable([1.0]) # 将存到资源容器"experiment0"  with g.container('experiment1'):    q = tf.FIFOQueue(10, tf.float32) # 将存到资源容器"experiment1"  with g.container(''):    v2 = tf.Variable([2.0]) # 将存到默认的资源容器
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

与graph collections相关的方法

一个图中可以有多个collections,也称为graph collections,每个collection都有一个名字,用来存储一组相关的对象,可以把一个collection看成一个list或array。它的标准名字有:GLOBAL_VARIABLESLOCAL_VARIABLESMODEL_VARIABLES等,定义在GraphKeys类里。

  • 存value:
    • tf.Graph.add_to_collection(name, value):把value存到名为name的collection中
    • tf.Graph.add_to_collections(names, value):把value存到名字在names上的所有collections中
  • 取value:
    • tf.Graph.get_collection(name, scope=None):根据名为name的collection,返回它中所有的values,即一个values的列表。传参scope充当一个过滤器,筛出指定scope中的values。
    • tf.Graph.get_collection_ref(name):同上,不同之处是此方法返回对collection本身的引用,而不是复制一份,故在上面的修改会起作用。
    • tf.Graph.get_all_collection_keys():返回一个collections的list
  • 清除value:
    • f.Graph.clear_collection(name):清除名为name的collection上所有的values

二、Operation类

1. 要点

  1. Operation实例就是数据流图中的节点负责tensors的计算,即输入是若干Tensor实例,输出也是若干Tensor实例。
  2. Operation实例与实例的type之间的区别:
    • 这里的Operation实例,也就是operation或op,与我们想的加法、减法等操作在概念上有略微差异,后者侧重于对方法的描述,前者则参与到图中,作为一个节点,叫”operator”更合适
    • 实例的type,也就是op type,才是指像加法、减法这样的操作方法,如“MatMul”表示矩阵乘这个操作方法
    • 每个Operation实例的名字在图中都是唯一的,因为对应一个特定节点,但是相同的操作方法op type在图中可以有多个。它们在protocol buffers分别被定义为NodeDef和OpDef
  3. 创建一个Operation实例有两种方法:
    • 第一种:调用一个op构造函数,如c=tf.matmul(a,b),则创建一个表示矩阵乘操作的op节点,其中a, b, c都为tensor,a和b作输入,c作输出
    • 第二种:调用方法Graph.create_op()
  4. 启动一个session后,执行Operation实例也有两种方法:
    • 第一种:把该op传入session的方法run()
    • 第二种:直接调用op.run(),这实际上是tf.get_default_session().run(op)的简写

2. Operation的属性

  • 内部属性:
    • _node_def为一个NodeDef对象,_op_def为一个OpDef对象
    • _id_value为op在图中的id
    • _graph为op所在图
    • _inputs为输入op的tensors列表,_outputs为输出op的tensors列表
    • _input_types为输入op的tensors的数据类型列表,_output_types为输出op的tensors的数据类型列表
    • _control_inputs为执行op前的控制依赖
    • _original_op为当前op需要的一个原op,如replica op还需要一个op,称为原op
    • _traceback为创建op时的调用栈call stack
    • _control_flow_context为包含当前op的当前控制流上下文
  • 对外属性:
    • tf.Operation.name:该operation的全名
    • tf.Operation.type:该operation的type,如MatMul
    • tf.Operation.inputs:该operation的输入,是一个Tensor对象的列表
    • tf.Operation.outputs:该operation的输出,也是一个Tensor对象的列表
    • tf.Operation.control_inputs:是一个Operation对象的列表,执行当前operation前,需要保证此列表中的所有Operation对象都已执行完毕
    • tf.Operation.graph:该operation所在的graph
    • tf.Operation.device:该operation所在的device,表示为一个字符串
    • tf.Operation.traceback:自该operation创建以来的调用栈
    • tf.Operation.node_def:该operation对应的NodeDef表示,使用了protocol buffer,见下面的message NodeDef
    • tf.Operation.op_def:该operation的type对应的OpDef表示,使用了protocol buffer,见下面的message OpDef
// 定义图中的一个节点message NodeDef {  // 本operator名,在一个图中是唯一的,可看成当前节点名  string name = 1;  // operation名,在一个图中可有重复,可看成操作的方法名  string op = 2;  // input列表,每个input表示为字符串"<node>:<src_output>",表明来自哪个op的哪个output索引  repeated string input = 3;  // 本节点所在device,举例为:  // 1) "@other/node":与另一个节点"other/node"共位置  // 2) "/job:worker/replica:0/task:1/gpu:3":全路径  // 3) "/job:worker/gpu:3":部分路径  // 4) "":无  string device = 4;  // 应包含OpDef的所有attrs  map<string, AttrValue> attr = 5;}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
// 定义一个操作message OpDef {  // operation名,等于NodeDef中的"op",名字采用CameCase格式,若首字符为"_"则为内部保留的操作  string name = 1;  // 定义一个argument message,用作一个input或output  message ArgDef {    // 当前input或output的名    sting name = 1;    // 给人读的描述    string description = 2;    // 当前input或output,可以接受一到多个tensors    // (1)当接受一个tensors时,要么设置type字段,要么设置type_attr字段,指向一个类型为"type"的attr    // (2)当接受多个type相同的tensors时,要设置number_attr字段,指向一个类型为"int"的attr,表示tensors的数目,当然还要设置type或type_attr字段    // (3)当接受多个type不同的tensors时,要设置type_list_attr字段,指向一个类型为"list(type)"的attr,不用设置type、type_attr和number_attr字段    DataType type = 3;    string type_attr = 4;    string number_attr = 5;    string type_list_attr = 5;    // 当前input或output是否为ref    bool is_ref = 16;  }  // 当前操作的所有输入  repeated ArgDef input_arg = 2;  // 当前操作的所有输出  repeated ArgDef output_arg = 3;  // 定义一个attr message  message AttrDef {    string name = 1;    string type = 2; // 如:"string", "list(string)", "int"    AttrValue default_value = 3;    string description = 4;    // 对于"int"型,有下面两字段    bool has_minimum = 5;    int64 minimum = 6;    AttrValue allowed_values = 7;  }  // 在op中定义的attr会加入NodeDef  repeated AttrDef attr = 4;  // 其他  OpDeprecation deprecation = 8;  string summary = 5;  string description = 6;  // 操作是否满足交换律  bool is_commutative = 18;  // 操作可接受2个以上的inputs,得出1个同类型的output,需满足交换律和结合律  bool is_aggregate = 16;  // 操作是否带状态,stateful ops不能在devices间移动,除非状态也能移动  bool is_stateful = 17; // 如:variables, queue  // 默认情况下,所有op的inputs必须是初始化后的tensors  bool allows_uninitialized_input = 19; // 如:assign}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

3. Operation的主要方法

构造方法

  • tf.Operation.__init__(node_def, g, inputs=None, ...):创建一个Operation实例,传入参数有很多,包括:
    • 必填参数:(1) node_def为一个node_def_pb2.NodeDef实例,包含了描述operation的属性,有nameopdevice但没有input,因为input是生成模型时才有的;(2)g为所在的图
    • 可选参数:(1) inputs为当前operation的输入,是一个Tensor对象的列表;(2) output_types当前operation的输出的类型,为是一个DType对象的列表;(3)control_inputs执行当前operation的前提,为一个operations或tensors的列表;(4) input_types为输入的类型,默认为[x.dtype.base_dtype for x in inputs];(5)original_op为一个关联的原op,如复制op的replica op,要给出那个op;(6) op_def为当前operation代表的op type,如”matmul”,定义在op_def_pb2.OpDef

执行operation的方法

  • tf.Operation.run(feed_dict=None, session=None):在session中运行当前operation,会一级级触发那些给当前operation提供inputs的所有直接或间接的operations。实际上,最后调用的是session.run(operation, feed_dict)
    • 传参:(1) feed_dict是一个dict( Tensor对象或tensor名 => 具体值 ),具体值可以是list、numpy ndarray、TensorProto或string;(2)session若没指定,则用当前线程的默认session。

获取operation信息的方法

  • tf.Operation.get_attr(name):根据名字返回当前operation的某个属性值。一个operation会有多个属性,定义在self._node_def.attr上。
  • tf.Operation.colocation_groups():返回当前operation的共位置组列表,格式为["loc:@<节点名即_node_def.name>", ...]

三、Tensor类

1. 要点

  1. Tensor对象作为表示数量的符号,要参与到数学计算中,就要重载Python的许多操作符
  2. Tensor对象Operation对象一起构建了图,如果说operation是节点,tensor更像是边,把不同的operations链接在一起
    • 一般来说,operation的输入除了tensor,还可以是其他类型,只要有能转化为tensor的相应支持,但是operation的输出只能是tensor。而且,任何一个Tensor实例都对应一个operation,作它的一个输出,所以tensor的创建离不开operation的创建,但Tensor实例并不保留operation的输出数值,而是提供一种计算这些数值的通道,用在session中
    • Tensor实例自创建就是某个operation的一个output,但它也可以是其他operation的一个input。把tensor传给其他operation作输入的过程,就是在operations间建立连接的过程,就是组网的过程.
  3. Tensor对象是符号,不是具体值
    • 启动session后,想得到tensor的具体数值,需要调用Session.run()t.eval()来计算t.eval()实际上是tf.get_default_session().run(t)
# 组网过程,下面的c、d、e都是Tensor实例,即一个符号,而不是具体值c = tf.constant([[1.0, 2.0], [3.0, 4.0]]) # constant op的输入是一个二维数组常量d = tf.constant([[1.0, 1.0], [0.0, 1.0]])e = tf.matmul(c, d)# 启动session来执行图sess = tf.Session()result = sess.run(e) # 这里result是一个numpy array,负责存储具体值
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2. Tensor的属性

  • 内部属性:
    • _op:以该tensor作输出的op
    • _value_index:在op输出中的索引
    • _dtype:数据类型
    • _shape:tensor形状
    • _consumers:使用该input做输入的operations列表,方便在图中游走
    • _handle_shape_handle_dtype:用于C++形状推断
  • 对外属性:
    • tf.Tensor.dtype:该tensor中元素的DType
    • tf.Tensor.name:该tensor的名字,为”op名:输出索引”
    • tf.Tensor.op:该tensor所在的operation,tensor作它的输出
    • tf.Tensor.value_index:该tensor位于它所在operation的输出列表中的索引
    • tf.Tensor.graph:该tensor所在的图,也是它的op的图
    • tf.Tensor.device:该tensor所在的device
    • tf.Tensor.shape:该tensor的形状,是一个TensorShape对象,如TensorShape([Dimension(3), Dimension(4)])

3. Tensor的主要方法

  • 构造方法
    • tf.Tensor.__init__(op, value_index, dtype):创建一个Tensor实例,op为以它做输出的operation,value_index为它在输出中的索引,dtype为它的元素数据类型
  • 求tensor值的方法
    • tf.Tensor.eval(feed_dict=None, session=None):启动session后,如果session的图与该tensor的图相同,则在session中对该tensor求值,会触发计算它的operation以及图中所依赖的前面operations。实际上,最终调用的是session.run(tensors, feed_dict)
      • 传参:(1) feed_dict是一个dict( Tensor对象或tensor名 => 具体值 ),具体值可以是list、numpy ndarray、TensorProto或string;(2)session若没指定,则用当前线程的默认session。
      • 返回:一个numpy array
  • 获取和设置tensor形状信息的方法
    • tf.Tensor.get_shape():获取该tensor的形状,返回是一个TensorShape对象,推断形状(shape inference)的过程不用启动session,但在operation中需注册一个用于推断形状的函数。比如,c=tf.constant([[1.0,2.0,3.0],[4.0,5.0,6.0]]),则调用c.get_shape()得到一个TensorShape([Dimension(2), Dimension(3)])
    • tf.Tensor.set_shape(shape):设置或更新该tensor的形状,如image.set_shape([28, 28, 3])
  • 获取tensor作输入的operations信息的方法
    • tf.Tensor.consumers():返回以该tensor做输入的所有operations

4. Tensor重载的Python operators

  • 算术操作:__(r)add__“+”,__(r)sub__“-“,__(r)mul__“*”,__(r)div__“/”,__(r)floordiv__“//”,__(r)truediv“/”,__(r)mod__“mod”,__neg__“-“,__(r)pow__“pow(x,y)”,__abs__“| |”
  • 逻辑操作:__(r)and__“&”,__(r)or__“|”,__invert__“~”,__(r)xor__“^”
  • 比较操作:__eq__“==”,__ge__“>=”,__gt__“>”,__le__“<=”,__lt__“<”
  • 其他:__getitem__“[ ]”,__hash__

默认都是元素级(element-wise)操作,除了__(r)mul__,源码注释这样说:# Dispatches cwise mul for "DenseDense" and "DenseSparse"

# __getitem__:限定到子tensorfoo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])print(foo[::2, ::-1].eval()) # => [[3,2,1], [9,8,7]]