TensorFlow计算图优化代码剖析

来源:互联网 发布:arm7也是keil编程吗 编辑:程序博客网 时间:2024/05/22 07:54

代码路径:tensorflow/core/grappler/optimizers
其中meta_optimizer.cc中的RunMetaOptimizer方法的调用触发对图的不同类型的优化操作.

优化操作分为一下几类:
1. pruning.裁剪,比如移除一些无用的操作(一旦图建立之后不再使用的stop gradient节点以及Identity节点),优化梯度计算.
2. constfold.常量打包.
3. layout. 对tensor的layout针对计算库以及设备进行调整.比如cudnn使用NCHW比较高效.
4. memory.
5. arithmetic.
6. autoparallel.
以上optimizer均可以同时使用.
下面我们对以上六种图优化手段逐一进行代码级剖析.

pruning

ModelPruner类有三个成员函数, name()方法返回名称, Optimize方法负责具体的优化操作. Feedback方法.
目的: 将所有不会被执行的节点都裁剪掉. 也就是那些不会被fanin的节点.如果没有指定fetch节点,将假设整个图都将被执行. 
不能移除必须被保留的节点(在nodes_to_prserve中);
不能移除驱动control依赖的节点;
不能移除无法确定移除后是否会新增control依赖的节点(比如,移除一个10条control edge同时驱动10条control edge,将新建100条edge);
不能移除与function链接的节点,因为会导致后面内联失败;
不能移除被其它设备驱动的节点,因为使用这些节点能够降低通信开销;
不能移除接收引用值的节点,将引用转换成非引用也不行(可能理解的不大对).

const folding

对图中常量进行合并优化.遍历图中节点,找出完全能够静态计算的节点,也就是说完全依赖于constants输入的.在CPU上将这些节点计算出来,并替换这些节点.没有CPU KERNEL的op不能进行constant folding.
辅助类:
EigenThreadPoolWrapper: 对Eigen库中threadpool进行封装,提供Schedule选取一个线程执行特定的函数.
DeviceSimple: 继承自DeviceBase

不能直接在switch节点上直接进行控制依赖的固定,因为和其它节点不同的是,执行时switch节点之后产生一个输出,并且我们必须确保控制依赖只有在对应的输出被触发时触发.我们一开始是通过查找到一个是switch节点输出节点关联的identity节点,并将它作为控制依赖的标定点.如果我们找不到这样的节点,那么就需要添加一个额外的identity节点.
辅助函数:
AddControlDependency函数:查找一个用于标定control dependency的节点,如果没有,需要添加.
MaterializeShapes: 将shape或者size或者rank操作实质化,因为计算图中tensor的这些属性是可以推导计算的.
IsFoldable: 如果一个节点的输入是空的,是不支持fold的.跳过指定preserve(白名单的除外)的节点.跳过const类型的操作,因为这些节点已经fold过了.跳过控制流节点.没必要fold没有出边的节点,除了白名单的节点.这些节点会在前期的常量folding过程中处理,如果用户想要取它们的值,那么需要保留.不能重复进行处理(fold检查并执行folding操作会出错).如果一个节点的所有输入都是静态可知,除了一种特殊情况,比如一个合并节点,只有第一个输入可用时,要求一个单独的constant输入可以被fold操作.(比较绕,具体还是建议大家看看代码).暂时不支持对string constant,可以理解为checkpoint时有bug.
EvaluateNode: 计算给定节点的输出,给定节点的输入,调用该opKernel->Compute函数.被EvaluateOneFoldable函数调用,根据输出新建一个节点.

layout

将NHWC的内存布局转换到GPU相关的操作NCHW(主要和卷积相关,cudnn使用NCHW比较高效)
辅助类:
GraphProcessor:主要提供了三个往计算图中添加const类型节点的方法(permutation/scalar/reduction,注意:三类详细的区别还不是很清楚,均需要指定device)
NodeProcessor:继承自GraphProcessor,三个成员方法updateAttrDataFormat(如果format为NHWC,那么设置为NCHW)/UpdateAttrShape(将输出的shape设置为CHW)/updateAttrSize/updateAttrStrides/updateAttrValue/updateAttrValueOfInput.用于修改输入输出数据的shape,同时更新内部数据结构属性值.AddNodeTranspose添加转换节点到计算图中.AddLayoutTransposeToInputs(调用AddNodeTranspose方法为输入添加layout transpose操作),AddLayoutTransposeToOutputs(为输出添加layout transpose).

AvgPoolGradProcessor: 继承自NodeProcessor.
BiasAddGradProcessor:继承自NodeProcessor.
Conv2DProcessor(stride为1,如果大于一,那么不进行layout转换操作,是否为有效padding)
Conv2DBackpropFilterProcessor:继承自Conv2DProcesso
Conv2DBackpropInputProcessor
FusedBatchNormGradProcessor
MaxPoolGradProcessor
AgnosticNodeProcessor
AddNProcessor,BinaryOpProcessor,ConcatProcessor,ReluGradProcessor,SliceProcessor:继承自AgnosticNodeProcessor
SliceProcessorConst:这个类主要是应对一种特殊情况,当第二三输入均为CONST时,首先会进行const folding操作,然后再进行slice优化.
SliceProcessorConcatOffset:当第二个输入为ConcatOffset.(比如inceptionV3中concat梯度计算)
SqueezeProcessor
SumProcessor
备注: conv2D,conv2DBackpropInput以及conv2DBackpropFilter,当filter size为1,或者等于输入image size时,NHWC的实现将采用特定的GEMM实现,通常来说会比NCHW的实现快.
DataLayoutOptimizer:继承自GrapProcessor.执行时需要遍历两次所有的节点shape,第一次是扩展支持NCHW的节点. 第二次是扩展layout不可知的节点.(collapse函数是为了合并所有的节点对,比如两个节点均是transpose操作,而且相反,那么可以合并.)

最后实现optimize方法,对图完成LayoutOptimizer操作.(代码实现中基于经验观察,如果引入的转换节点个数超过30个,那么不使用gemm的实现能够获得更好的性能)

Memory

主要目的: 将tensor从设备内存中换入换出.
构造函数指定优化级别(autonomy级别为memory optimizer),提到rewriterConfig,指定recomputation_targets_name_prefix以及memory_optimizer_target_node_name_prefix.

备注:这里的内存优化策略是将forward的部分op结果swap到host-memory,然后计算backward gradients时重新计算该op,达到节省显存的目的.或者标定一些节点为recomputed_node.

GetCheapToRecomputeOps方法返回一个op名称的数组,标记为这些操作为轻量级可recompute的操作.目前的实现仅仅提供一些静态的数组,后期可能会提供一个代价模型更加合理的op列表.
FindCandidateRecomputeNodes方法:找出所有feed给目标节点的recomputable ops.
connected_subgraph: 为candidateRecommputable 节点生成连接图.
GetOpGroupsToRecompute方法:基于should_recompute方法,找出几组op一起recompute.返回一个需要recompute的一组子图.
GetMaxDownstreamComponents:计算最大的拓扑数量(1,目标节点的组成,即梯度节点,feed by recomputation),(2,每个recomputed node的子重计算节点,) 当componet的数量大于这个值的时候,需要为一个重计算添加一个控制依赖.
AddRecomputeControlDependencyNodes:修改计算图,添加触发器节点,返回一个recomputed_source_nodes到trigger nodes的映射.

BuildSwapPair方法: 创建swap-in/swap-out节点对,
FindSwapTrigger方法: max_trigger_time存储了swap操作需要提前执行,将数据载入回到加速卡上,同时不影响下游计算的时间.也就是swap操作需要提前执行.

Optimize方法: 1. 找出所有_swap_to_host的节点 2. 评估每个节点需要swap的数据大小,以及传输时间(假设是基于PCIE 16GBps). 3. 遍历swap节点,找出swap trigger,找出我们需要将数据交换回来之前的节点执行,并且添加一个从这个节点到swap节点的控制依赖. 4. 将属性标记为swap_to_host的节点所有的tensor交换出去.同时添加必要的控制依赖用于延迟swap操作的执行.

备注:从代码逻辑来看,是将target节点的输入进行swap,重计算target的值.

auto parallel

主要操作: 自动并行一个图,通过将batch维度进行切分.可以理解为数据并行.
根据可用的gpu数量添加replica节点以及shared节点.

arithmetic

主要操作: 通过降低数值计算的复杂度来优化TF计算.
对数值表达式进行简化,移除冗余计算,表达式替换等手段.

以上内容仅仅很粗略的阅读了一些代码,后面会不断细化.