Iterative Pruning
Principles of Model Pruning
The paper reduces the number of network parameters by pruning unimportant weights, in three main steps:
(1) Train the full network normally.
(2) Prune unimportant connections. Connections with larger weights are generally considered more important, so a threshold is set and weights below it are removed.
(3) Retrain the network, and store the network parameters as sparse matrices.
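The three steps above can be sketched in plain NumPy. This is an illustrative sketch (the function and variable names are ours, not from the paper): prune by magnitude, then mask gradients during retraining so pruned connections stay at zero.

```python
import numpy as np

def magnitude_prune(weights, threshold):
    """Zero out weights whose magnitude is below the threshold,
    and return both the pruned weights and the keep-mask."""
    mask = np.abs(weights) >= threshold
    return weights * mask, mask

# Example: prune a small weight matrix at threshold 0.1
w = np.array([[0.5, -0.02],
              [0.08, -0.9]])
pruned, mask = magnitude_prune(w, 0.1)
# pruned: only the entries with |w| >= 0.1 survive.
# During retraining, the gradient is masked the same way
# (grad * mask), so pruned connections remain zero.
```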
Source Code Analysis
Using the MNIST CNN model as an example, download the GitHub source:
git clone https://github.com/garion9013/impl-pruning-TF
Training:
python train.py -2 -3
During training, the pretrained model model_ckpt_dense is loaded first.
The fully connected layers fc1 and fc2 are then pruned:
```python
def apply_prune(weights):
    dict_nzidx = {}
    for target in papl.config.target_layer:
        wl = "w_" + target
        print(wl + " threshold:\t" + str(papl.config.th[wl]))
        # Get target layer's weights
        weight_obj = weights[wl]
        weight_arr = weight_obj.eval()
        # Apply pruning
        weight_arr, w_nzidx, w_nnz = papl.prune_dense(
            weight_arr, name=wl, thresh=papl.config.th[wl])
        # Store pruned weights as tensorflow objects
        dict_nzidx[wl] = w_nzidx
        sess.run(weight_obj.assign(weight_arr))
    return dict_nzidx
```
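The body of papl.prune_dense is not shown in this excerpt. Judging from its call site (it returns the pruned array, a non-zero-index mask, and a count), it plausibly looks like the following NumPy sketch; the real implementation may differ in detail.

```python
import numpy as np

def prune_dense(weight_arr, name="None", thresh=0.005):
    """Sketch of the dense pruning step: zero weights below |thresh|,
    return the pruned array, the keep-mask of surviving (non-zero)
    weights, and the surviving-weight count."""
    under_threshold = np.abs(weight_arr) < thresh
    weight_arr = np.where(under_threshold, 0.0, weight_arr)
    nzidx = ~under_threshold        # mask of kept weights
    nnz = int(nzidx.sum())          # number of surviving weights
    return weight_arr, nzidx, nnz
```

The returned mask is what apply_prune stores in dict_nzidx, to be reused later for gradient masking.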
After pruning, entries of the fully connected layers' weight matrices below the threshold are zero, and the pruned model is saved as model_ckpt_dense_pruned.
The pruned model is then retrained. For the fully connected layers, the gradient is masked so that only gradients at positions where the weight matrix exceeds the threshold are kept:
```python
def apply_prune_on_grads(grads_and_vars, dict_nzidx):
    # Mask gradients with pruned elements
    for key, nzidx in dict_nzidx.items():
        count = 0
        for grad, var in grads_and_vars:
            if var.name == key + ":0":
                nzidx_obj = tf.cast(tf.constant(nzidx), tf.float32)
                grads_and_vars[count] = (tf.multiply(nzidx_obj, grad), var)
            count += 1
    return grads_and_vars
```
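The effect of this gradient masking can be shown with a single SGD step in NumPy (names here are ours, for illustration): multiplying the gradient by the keep-mask before the update guarantees pruned connections stay exactly zero throughout retraining.

```python
import numpy as np

def masked_sgd_step(w, grad, mask, lr=0.1):
    """One SGD step with the gradient masked by the keep-mask,
    so pruned (mask == 0) weights never move away from zero."""
    return w - lr * (grad * mask)

w = np.array([0.5, 0.0, -0.9])     # middle weight already pruned
grad = np.array([0.1, 0.3, -0.2])
mask = np.array([1.0, 0.0, 1.0])   # 0 where pruned
w_new = masked_sgd_step(w, grad, mask)
# The pruned weight stays exactly zero after the update.
```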
Retraining produces the model model_ckpt_dense_retrained.
Prune again: set the entries of the fully connected layers' weight matrix that fall below the threshold to zero, obtaining weight1, and store weight1 as a sparse matrix, i.e., record the non-zero values of weight1 together with their indices.
Pruning:
```python
def prune_tf_sparse(weight_arr, name="None", thresh=0.005):
    assert isinstance(weight_arr, np.ndarray)
    under_threshold = abs(weight_arr) < thresh
    weight_arr[under_threshold] = 0
    values = weight_arr[weight_arr != 0]
    indices = np.transpose(np.nonzero(weight_arr))
    shape = list(weight_arr.shape)
    count = np.sum(under_threshold)
    print("Non-zero count (Sparse %s): %s" % (name, weight_arr.size - count))
    return [indices, values, shape]
```
Build the sparse representation:
```python
def gen_sparse_dict(dense_w):
    sparse_w = dense_w
    for target in papl.config.target_all_layer:
        target_arr = np.transpose(dense_w[target].eval())
        sparse_arr = papl.prune_tf_sparse(target_arr, name=target)
        sparse_w[target + "_idx"] = tf.Variable(
            tf.constant(sparse_arr[0], dtype=tf.int32), name=target + "_idx")
        sparse_w[target] = tf.Variable(
            tf.constant(sparse_arr[1], dtype=tf.float32), name=target)
        sparse_w[target + "_shape"] = tf.Variable(
            tf.constant(sparse_arr[2], dtype=tf.int32), name=target + "_shape")
    return sparse_w
```
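The (indices, values, shape) triplet is enough to reconstruct the dense matrix at deployment time. A minimal NumPy sketch of that round trip (the helper name is ours):

```python
import numpy as np

def sparse_to_dense(indices, values, shape):
    """Rebuild a dense matrix from the (indices, values, shape)
    triplet produced by prune_tf_sparse."""
    dense = np.zeros(shape)
    dense[tuple(indices.T)] = values
    return dense

w = np.array([[0.0, 0.4],
              [0.7, 0.0]])
indices = np.transpose(np.nonzero(w))  # coordinates of non-zeros
values = w[w != 0]                     # their values, row-major
restored = sparse_to_dense(indices, values, list(w.shape))
# restored is identical to w.
```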
The model is then saved as model_ckpt_sparse_retrained.
After 10 iterations, the model shrinks from the original 13 MB to 3.9 MB, a 70% compression, with a test accuracy of 0.9708.
Testing:
Original model:
python deploy_test.py -d -m model_ckpt_dense
Compressed model:
python deploy_test_pruned.py -d -m model_ckpt_sparse_retrained