Iterative Pruning


Principle of Model Pruning

The paper reduces the number of network parameters by pruning unimportant weights, in three main steps:

(1) Train the full network normally.

(2) Prune unimportant connections. Weights with larger magnitudes are generally considered more important, so a threshold is set and weights whose magnitude falls below it are removed (a minimal sketch follows this list).

(3) Retrain the network and store its parameters as sparse matrices.
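
Step (2) is the core operation. Here is a minimal NumPy sketch of magnitude-threshold pruning; the weight values and the threshold 0.005 are illustrative, not taken from the paper or the code below:

import numpy as np

# Step (2) in isolation: zero out weights whose magnitude falls
# below the threshold (illustrative values)
weights = np.array([[0.8, -0.002, 0.3],
                    [0.001, -0.5, 0.004]])
thresh = 0.005
mask = np.abs(weights) >= thresh   # True where a connection survives
pruned = weights * mask            # pruned positions become 0.0
print(pruned)                      # [[0.8, 0.0, 0.3], [0.0, -0.5, 0.0]]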

Source code analysis:

Taking the MNIST CNN model as an example, clone the GitHub source:

git clone https://github.com/garion9013/impl-pruning-TF

Training:

python train.py -2 -3

During training, an already-trained model is loaded: model_ckpt_dense.
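
Under the hood this is the standard TF1 checkpoint-restore pattern. A minimal sketch, assuming the model's variables have already been defined (the saver construction here is illustrative, not the repo's exact code):

import tensorflow as tf

# Restore the previously trained dense weights before pruning;
# "model_ckpt_dense" matches the checkpoint name above
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "model_ckpt_dense")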

The fully connected layers fc1 and fc2 are then pruned:

def apply_prune(weights):
    dict_nzidx = {}
    for target in papl.config.target_layer:
        wl = "w_" + target
        print(wl + " threshold:\t" + str(papl.config.th[wl]))
        # Get target layer's weights
        weight_obj = weights[wl]
        weight_arr = weight_obj.eval()
        # Apply pruning
        weight_arr, w_nzidx, w_nnz = papl.prune_dense(weight_arr, name=wl,
                                                      thresh=papl.config.th[wl])
        # Store pruned weights as tensorflow objects
        dict_nzidx[wl] = w_nzidx
        sess.run(weight_obj.assign(weight_arr))
    return dict_nzidx
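
papl.prune_dense itself is not shown in the excerpt above. A plausible sketch, assuming it mirrors prune_tf_sparse further below: threshold by magnitude and return the pruned array, the boolean mask of survivors, and the survivor count:

import numpy as np

def prune_dense(weight_arr, name="None", thresh=0.005):
    # Zero out entries whose magnitude falls below the threshold
    under_threshold = np.abs(weight_arr) < thresh
    weight_arr[under_threshold] = 0
    nzidx = ~under_threshold                       # mask of surviving weights
    nnz = weight_arr.size - np.sum(under_threshold)
    print("Non-zero count (%s): %d" % (name, nnz))
    return weight_arr, nzidx, nnz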

After pruning, the entries of the fully connected layers' weight matrices that fall below the threshold are zero, and the pruned model is saved as model_ckpt_dense_pruned.

Retrain the pruned model:

During retraining, the gradient computation for the fully connected layers keeps only the gradients at positions where the weights survived pruning; gradients at pruned positions are masked to zero:

def apply_prune_on_grads(grads_and_vars, dict_nzidx):
    # Mask gradients with pruned elements
    for key, nzidx in dict_nzidx.items():
        count = 0
        for grad, var in grads_and_vars:
            if var.name == key + ":0":
                nzidx_obj = tf.cast(tf.constant(nzidx), tf.float32)
                grads_and_vars[count] = (tf.multiply(nzidx_obj, grad), var)
            count += 1
    return grads_and_vars
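
A hypothetical sketch of how this masking plugs into retraining, assuming a loss tensor cross_entropy already exists (the optimizer choice and learning rate are illustrative):

import tensorflow as tf

# Compute gradients explicitly, mask out pruned positions, then apply
optimizer = tf.train.AdamOptimizer(1e-4)
grads_and_vars = optimizer.compute_gradients(cross_entropy)
grads_and_vars = apply_prune_on_grads(grads_and_vars, dict_nzidx)
train_step = optimizer.apply_gradients(grads_and_vars)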

After retraining, the model model_ckpt_dense_retrained is obtained.

Prune again: set the entries of the fully connected layers' weight matrices that fall below the threshold to 0, yielding weight1, and store weight1 as a sparse matrix, i.e., record the nonzero values of weight1 together with their indices.

The pruning function:

def prune_tf_sparse(weight_arr, name="None", thresh=0.005):
    assert isinstance(weight_arr, np.ndarray)
    # Zero out weights whose magnitude falls below the threshold
    under_threshold = abs(weight_arr) < thresh
    weight_arr[under_threshold] = 0
    # Record the nonzero values, their indices, and the dense shape
    values = weight_arr[weight_arr != 0]
    indices = np.transpose(np.nonzero(weight_arr))
    shape = list(weight_arr.shape)
    count = np.sum(under_threshold)
    print("Non-zero count (Sparse %s): %s" % (name, weight_arr.size - count))
    return [indices, values, shape]
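
A small worked example of the return format (the matrix values are made up):

import numpy as np

w = np.array([[0.2, 0.001, 0.0],
              [0.0, -0.3, 0.004]])
indices, values, shape = prune_tf_sparse(w, name="demo")
# indices -> [[0, 0], [1, 1]]   (row, col) of each surviving weight
# values  -> [0.2, -0.3]
# shape   -> [2, 3]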

Building the sparse representation:

def gen_sparse_dict(dense_w):
    sparse_w = dense_w
    for target in papl.config.target_all_layer:
        target_arr = np.transpose(dense_w[target].eval())
        sparse_arr = papl.prune_tf_sparse(target_arr, name=target)
        sparse_w[target + "_idx"] = tf.Variable(tf.constant(sparse_arr[0], dtype=tf.int32),
                                                name=target + "_idx")
        sparse_w[target] = tf.Variable(tf.constant(sparse_arr[1], dtype=tf.float32),
                                       name=target)
        sparse_w[target + "_shape"] = tf.Variable(tf.constant(sparse_arr[2], dtype=tf.int32),
                                                  name=target + "_shape")
    return sparse_w

The model is then saved as model_ckpt_sparse_retrained.
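
A plausible sketch of that save step, assuming an active session sess and the sparse_w dict from above (tf.train.Saver accepts a dict of variables to checkpoint):

import tensorflow as tf

# Checkpoint only the sparse representation of the weights
saver = tf.train.Saver(sparse_w)
saver.save(sess, "model_ckpt_sparse_retrained")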

After 10 iterations, the model shrinks from the original 13 MB to 3.9 MB, roughly 70% compression (3.9/13 ≈ 0.3), with a test accuracy of 0.9708.

Testing the results:

Original model:

python deploy_test.py -d -m model_ckpt_dense

Compressed model:

python deploy_test_pruned.py -d -m model_ckpt_sparse_retrained
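
To run inference, deploy_test_pruned.py has to rebuild dense weights from the stored (indices, values, shape) triples. A sketch using tf.sparse_to_dense, where the key "w_fc1" is an illustrative layer name and the final transpose undoes the np.transpose applied in gen_sparse_dict:

import tensorflow as tf

# Rebuild the (transposed) dense matrix from indices, values and shape
dense_t = tf.sparse_to_dense(sparse_indices=sparse_w["w_fc1_idx"],
                             output_shape=sparse_w["w_fc1_shape"],
                             sparse_values=sparse_w["w_fc1"])
# Transpose back to the original orientation
w_fc1 = tf.transpose(dense_t)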
