Channel Pruning for Accelerating Very Deep Neural Networks: Code Walkthrough


The full pruning pipeline consists of the following steps:
1. Weight decomposition and channel pruning
2. Fine-tuning the pruned model to recover its representational capacity
3. Using the fine-tuned caffemodel as the pre-trained model for Faster R-CNN and fine-tuning again

The main functions are explained below; some of the smaller helper functions are covered toward the end of the article.


Network pruning

Pruning is launched with the following command:
./run.sh python3 train.py -action c3 -caffe [GPU0]

if __name__ == '__main__':
    args = parse_args()
    cfgs.set_nBatches(dcfgs.nBatches)
    dcfgs.dic.option = 1
    DEBUG = 1
    if args.action == cfgs.Action.c3:
        c3()  # entry point for the whole pipeline

The c3 function drives the whole pipeline:

def c3(pt=cfgs.vgg.model, model=cfgs.vgg.weights):
    dcfgs.splitconvrelu = True
    cfgs.accname = 'accuracy@5'

    def solve(pt, model):
        net = Net(pt, model=model)
        net.load_frozen()  # load from the frozen file generated in step1
        # call the main pruning routine R3: spatial decomposition, channel
        # decomposition and channel pruning; returns the pruned weights and
        # the pruned network definition
        WPQ, new_pt = net.R3()
        return {"WPQ": WPQ, "new_pt": new_pt}

    def stepend(new_pt, model, WPQ):
        net = Net(new_pt, model=model)
        net.WPQ = WPQ
        net.finalmodel(save=False)  # load the weights from WPQ into the net
        net.dis_memory()
        new_pt, new_model = net.save(prefix='3c')  # save a new model with prefix 3c
        print('caffe test -model', new_pt, '-weights', new_model)
        return {"final": None}

    worker = Worker()
    outputs = worker.do(step0, pt=pt, model=model)  # run step0
    pt = outputs['pt']
    outputs = worker.do(step1, **outputs)  # run step1
    outputs['pt'] = mem_pt(pt)
    outputs = worker.do(solve, **outputs)  # run solve
    printstage("saving")
    outputs = worker.do(stepend, model=model, **outputs)  # run stepend

Step 0: network preprocessing

def step0(pt, model):
    net = Net(pt, model=model, noTF=1)
    # WPQ holds the pruned weights, which are later written into the caffemodel
    WPQ, pt, model = net.preprocess_resnet()  # handle the ReLU / BatchNorm / Scale layers
    return {"WPQ": WPQ, "pt": pt, "model": model}

Step 1: data preparation

def step1(pt, model, WPQ, check_exist=False):
    net = Net(pt, model, noTF=1)
    model = net.finalmodel(WPQ)  # load WPQ into the model (explained in detail below)
    convs = net.convs  # names of all conv layers in the proto, including the last conv
    # Sample 5000 images (500 batches) from the dataset and take only 10
    # points per feature map, i.e. 50000 sample points in total; the
    # feature-map values and channel counts for these points are stored
    # in a frozen pickle file.
    net.freeze_images(check_exist=check_exist, convs=convs)
    return {"model": model}

Step 2: solve, the main pruning function

def solve(pt, model):
    net = Net(pt, model=model)
    net.load_frozen()  # load from the frozen file generated in step1
    # call the main pruning routine R3: spatial decomposition, channel
    # decomposition and channel pruning; returns the pruned weights and
    # the pruned network definition
    WPQ, new_pt = net.R3()
    return {"WPQ": WPQ, "new_pt": new_pt}

Step 3: stepend, saving the final caffemodel

def stepend(new_pt, model, WPQ):
    net = Net(new_pt, model=model)
    net.WPQ = WPQ
    net.finalmodel(save=False)  # load the weights from WPQ into the net
    net.dis_memory()
    new_pt, new_model = net.save(prefix='3c')  # save a new model with prefix 3c
    print('caffe test -model', new_pt, '-weights', new_model)
    return {"final": None}

Implementation details of the functions above

The R3 function

Called by the solve function in step 2; all of the spatial decomposition, channel decomposition and channel pruning logic lives in R3:

def R3(self):
    speed_ratio = dcfgs.dic.keep  # 3, i.e. a 3x speed-up
    prefix = str(int(speed_ratio) + 1) + 'x'
    DEBUG = True
    convs = self.convs
    self.WPQ = dict()
    self.selection = dict()
    self._mem = True
    end = 5  # only the conv groups in [1, 5) are considered
    alldic = ['conv%d_1' % i for i in range(1, end)] + ['conv%d_2' % i for i in range(3, end)]
    pooldic = ['conv1_2', 'conv2_2']  # , 'conv3_3']  conv layers followed by a pooling layer
    # empirically chosen numbers of channels to keep
    #                          original  faster   3C VGG
    rankdic = {'conv1_1': 17,  # 64      12       64
               'conv1_2': 17,  # 64      64       V22 H22 P58
               'conv2_1': 37,  # 128     21       V49 H49 P117
               'conv2_2': 47,  # 128     128      V62 H62 P117
               'conv3_1': 83,  # 256     73       V110 H110 P237
               'conv3_2': 89,  # 256     58       V118 H118 P242
               'conv3_3': 106, # 256     256      V114 H114 P256
               'conv4_1': 175, # 512     121      V233 H233 P475
               'conv4_2': 192, # 512     166      V256 H256 P457
               'conv4_3': 227, # 512     512      V302 H302 P512
               'conv5_1': 398, # 512     512      V398 H398 P512
               'conv5_2': 390, # 512     512      V390 H390 P512
               'conv5_3': 379} # 512     512      V379 H379 P512
    c_ratio = 1.15

    def getX(name):
        x = self.extract_XY(self.bottom_names[name][0], name)
        return np.rollaxis(x.reshape((-1, 3, 3, x.shape[1])), 3, 1).copy()

    def setConv(c, d):
        if c in self.selection:
            self.param_data(c)[:, self.selection[c], :, :] = d
        else:
            self.set_param_data(c, d)

    t = Timer()
    # zip pairs the two lists into tuples, giving
    # [(conv1, conv2), (conv2, conv3), ..., (conv5, pool5)];
    # conv is the current layer, convnext the following conv or pool layer
    for conv, convnext in zip(convs[1:], convs[2:] + ['pool5']):
        # spatial and channel decomposition create new conv_V / conv_H / conv_P
        # layers that replace the original conv layer
        conv_V = underline(conv, 'V')  # append _V to the layer name
        conv_H = underline(conv, 'H')
        conv_P = underline(conv, 'P')
        W_shape = self.param_shape(conv)
        d_c = int(W_shape[0] / c_ratio)
        rank = rankdic[conv]
        d_prime = rank
        if d_c < rank:
            d_c = rank

        '''spatial decomposition'''
        if True:
            t.tic()
            weights = self.param_data(conv)
            if conv in self.selection:
                weights = weights[:, self.selection[conv], :, :]
            if 1:
                Y = self._feats_dict[conv] - self.param_b_data(conv)
                X = getX(conv)
                if conv in self.selection:
                    X = X[:, self.selection[conv], :, :]
                # SVD: the left singular vectors become V, the right singular
                # vectors times the singular values become H.
                # V becomes a low-dimensional 3x1 conv layer, H a 1x3 conv
                # layer with the same number of outputs; VHr is the low-rank
                # reconstruction of the weights.
                V, H, VHr, b = VH_decompose(weights, rank=rank, DEBUG=DEBUG, X=X, Y=Y)
                self.set_param_b(conv, b)
            self.WPQ[conv_V] = V  # V becomes the parameters of the new conv_V layer
            setConv(conv, VHr)    # set conv's weights to the low-rank VHr
            # H becomes the weights of conv_H; conv's bias becomes conv_H's bias
            self.WPQ[(conv_H, 0)] = H
            self.WPQ[(conv_H, 1)] = self.param_b_data(conv)
            t.toc('spatial_decomposition')  # time the spatial decomposition
        self.insert(conv, conv_H)  # insert conv_H after the conv layer

        '''channel decomposition (ITQ_decompose)'''
        if True:
            t.tic()
            feats_dict, _ = self.extract_features(names=conv, points_dict=self._points_dict, save=1)
            Y = feats_dict[conv]
            W1, W2, B, W12 = ITQ_decompose(Y, self._feats_dict[conv], H, d_prime, bias=self.param_b_data(conv), DEBUG=0, Wr=VHr)
            # set W to low-rank W, asymmetric solver
            setConv(conv, W12.copy())
            self.set_param_b(conv, B.copy())
            # save W_prime and P params
            W_prime_shape = [d_prime, H.shape[1], H.shape[2], H.shape[3]]
            P_shape = [W2.shape[0], W2.shape[1], 1, 1]
            self.WPQ[(conv_H, 0)] = W1.reshape(W_prime_shape)
            self.WPQ[(conv_H, 1)] = np.zeros(d_prime)
            self.WPQ[(conv_P, 0)] = W2.reshape(P_shape)
            self.WPQ[(conv_P, 1)] = B
            self.insert(conv_H, conv_P, pad=0, kernel_size=1, bias=True, stride=1)
            t.toc('channel_decomposition')

        '''channel pruning'''
        if dcfgs.dic.vh and (conv in alldic or conv in pooldic) and (convnext in self.convs):
            t.tic()
            if conv in pooldic:
                X_name = self.bottom_names[convnext][0]
            else:
                X_name = conv
            # dictionary_kernel prunes the channels feeding X_name; idxs marks
            # the channels that survive, and W2, B2 are the weights and bias of
            # convnext adjusted for the pruning
            idxs, W2, B2 = self.dictionary_kernel(X_name, None, d_c, convnext, None)
            self.selection[convnext] = idxs  # record the pruning indices for the next layer
            self.param_data(convnext)[:, ~idxs, ...] = 0        # zero out the pruned channels
            self.param_data(convnext)[:, idxs, ...] = W2.copy() # set the surviving channels to the adjusted W2
            self.set_param_b(convnext, B2)                      # set the bias to the adjusted B2
            if (conv_P, 0) in self.WPQ:
                key = conv_P
            else:
                key = conv_H
            self.WPQ[(key, 0)] = self.WPQ[(key, 0)][idxs]  # WPQ holds the values of the layer being pruned
            self.WPQ[(key, 1)] = self.WPQ[(key, 1)][idxs]
            self.set_conv(key, num_output=sum(idxs))       # update this layer's num_output
            t.toc('channel_pruning')

        # turn the conv_H entry in WPQ into a new layer
        H_params = {'bias': True}
        H_params.update(self.infer_pad_kernel(self.WPQ[(conv_H, 0)], conv))
        self.set_conv(conv_H, **H_params)
        # replace the original conv layer with conv_V from WPQ
        V_params = self.infer_pad_kernel(self.WPQ[conv_V], conv)
        self.set_conv(conv, new_name=conv_V, **V_params)

    new_pt = self.save_pt(prefix=prefix)
    return self.WPQ, new_pt
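To make the bookkeeping concrete, here is a small shape sketch (not repo code; the sizes are the rankdic values for conv4_2, with d_prime = rank as in R3) of what R3 leaves behind for one conv layer: the original n x c x 3 x 3 kernel is replaced by a 3x1 conv_V, a 1x3 conv_H and a 1x1 conv_P, and channel pruning then additionally shrinks the input channels of the following conv.

import numpy as np

# hypothetical sizes: conv4_2 has 512 input/output channels; rankdic keeps 192
n, c, k = 512, 512, 3
rank = d_prime = 192

conv_V = np.zeros((rank, c, k, 1))        # spatial decomposition, 3x1 factor
conv_H = np.zeros((d_prime, rank, 1, k))  # 1x3 factor, d' outputs after channel decomposition
conv_P = np.zeros((n, d_prime, 1, 1))     # 1x1 projection back to n channels

params_before = n * c * k * k
params_after = conv_V.size + conv_H.size + conv_P.size
print(params_before / params_after)       # ~4.7x fewer parameters in this layer, before pruning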

Next come the helper functions used by the main routine above.

The finalmodel function

def finalmodel(self, WPQ=None, **kwargs):
    """load weights into caffemodel"""
    if WPQ is None:
        WPQ = self.WPQ
    # load the weights from WPQ into the current params and save them as a
    # VH-prefixed caffemodel
    return self.linear(WPQ, **kwargs)

def linear(self, WPQ, prefix='VH', save=True, DEBUG=0):
    # load the weights from WPQ into the current params and save a caffemodel
    for i, j in WPQ.items():
        if save:
            self.set_param_data(i, j)  # copy the values in j directly into param i
        else:
            self.ch_param_data(i, j)   # if the shape changed, reshape first, then copy j into param i
    if save:
        return self.save_caffemodel(prefix=prefix)  # save the updated params as a prefixed caffemodel

The preprocess_resnet function

def preprocess_resnet(self):
    # type2names returns the names of all layers of the given type
    sums = self.type2names('Eltwise')  # each Eltwise is a shortcut junction
    convs = self.type2names()
    ReLUs = self.type2names("ReLU")
    projs = {}
    WPQ, pt, model = {}, None, self.caffemodel_dir
    if dcfgs.model not in [cfgs.Models.xception, cfgs.Models.resnet] or not dcfgs.res.bn:
        # if the model is not Xception/ResNet, fold the BatchNorm and Scale
        # layers into the conv layers; WPQ holds the new weights and biases,
        # pt the new prototxt, model the caffemodel
        WPQ, pt, model = self.merge_bn()
    if dcfgs.splitconvrelu:  # True
        # previously each ReLU layer's top was its conv; now the top becomes
        # the ReLU itself (compare the generated prototxt to see the difference)
        pt = self.seperateConvReLU()
    return WPQ, pt, model
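merge_bn deserves a quick illustration. A conv followed by BatchNorm (running mean mu, variance var) and Scale (gamma, beta) computes an affine function of the conv output, so both layers can be folded into the conv's weights and bias at inference time. A minimal numpy sketch of the algebra (illustrative only, not the repo's implementation):

import numpy as np

n, c, k = 8, 4, 3
W = np.random.randn(n, c, k, k)
b = np.random.randn(n)
mu, var = np.random.randn(n), np.random.rand(n) + 0.5
gamma, beta = np.random.randn(n), np.random.randn(n)
eps = 1e-5

# fold BN + Scale into the conv: y_bn = scale * (conv(x) + b - mu) + beta
scale = gamma / np.sqrt(var + eps)
W_folded = W * scale[:, None, None, None]
b_folded = (b - mu) * scale + beta

# check on random kernel-sized patches (a conv evaluated at one position)
x = np.random.randn(2, c, k, k)
y = np.tensordot(x, W, axes=([1, 2, 3], [1, 2, 3])) + b
y_bn = (y - mu) / np.sqrt(var + eps) * gamma + beta
y_folded = np.tensordot(x, W_folded, axes=([1, 2, 3], [1, 2, 3])) + b_folded
print(np.allclose(y_bn, y_folded))  # True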

The extract_features function

This function loads the data prepared in step1 (the frozen pickle). It runs dcfgs.nBatches = 500 batches and samples nPointsPerLayer = 10 points from each layer's feature map; with 10 images per batch, each layer's feats_dict ends up with shape (50000, channels).
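The indexing is easier to follow in a distilled sketch. Assuming a network batch size of 10 images (which matches the printed nFeatsPerBatch = 100; this sketch is illustrative, not the repo's code), the loop below reproduces how 500 batches x 10 images x 10 random spatial points fill a (50000, channels) matrix:

import numpy as np

nBatches = 500          # dcfgs.nBatches
nPicsPerBatch = 10      # images per forward pass (assumed batch size)
nPointsPerLayer = 10    # dcfgs.nPointsPerLayer
channels = 64           # e.g. conv1_1 output channels

nFeatsPerBatch = nPointsPerLayer * nPicsPerBatch   # 100
nFeats = nFeatsPerBatch * nBatches                 # 50000
feats = np.empty((nFeats, channels))

idx = 0
for batch in range(nBatches):
    # stand-in for one forward pass: (nPicsPerBatch, channels, H, W)
    feat = np.random.randn(nPicsPerBatch, channels, 14, 14)
    randx = np.random.randint(0, feat.shape[2], nPointsPerLayer)
    randy = np.random.randint(0, feat.shape[3], nPointsPerLayer)
    for point, (x, y) in enumerate(zip(randx, randy)):
        i_from = idx + point * nPicsPerBatch
        feats[i_from:i_from + nPicsPerBatch] = feat[:, :, x, y]
    idx += nFeatsPerBatch

print(feats.shape)  # (50000, 64)

The full implementation follows.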

# feats_dict, points_dict = self.extract_features(names=convs, save=1, **kwargs)
def extract_features(self, names=[], nBatches=None, points_dict=None, save=False):
    assert nBatches is None, "deprecate"  # deprecated argument
    nBatches = dcfgs.nBatches  # 500 batches are sampled
    nPointsPerLayer = dcfgs.nPointsPerLayer  # 10 points sampled per feature map
    if not isinstance(names, list):  # wrap a single name in a list
        names = [names]
    inner = False
    # if names contains a single layer, check whether it feeds an fc layer
    if len(names) == 1:
        for top in self.innerproduct:  # for each fc layer
            if names[0] in self.bottom_names[top]:  # names[0] is the layer feeding this fc layer
                inner = True  # there is an fc layer downstream
                nBatches = dcfgs.nBatches_fc
                break
    DEBUG = False
    pads = dict()
    shapes = dict()
    feats_dict = dict()  # key: layer name; value: sampled feature-map points

    def set_points_dict(name, data):
        assert name not in points_dict
        points_dict[name] = data

    dcfgs.data = cfgs.Data.lmdb
    if save:
        if points_dict is None:  # True
            frozen_points = False
            points_dict = dict()
            if 0 and self._mem: self.usexyz()
            set_points_dict("nPointsPerLayer", nPointsPerLayer)
            set_points_dict("nBatches", nBatches)
        else:
            frozen_points = True
            if nPointsPerLayer != points_dict["nPointsPerLayer"] or nBatches != points_dict["nBatches"]:
                print("overwriting nPointsPerLayer, nBatches with frozen_points")
            nPointsPerLayer = points_dict["nPointsPerLayer"]
            nBatches = points_dict["nBatches"]
    assert len(names) > 0
    nPicsPerBatch = self.blobs_num(names[0])  # images per batch
    nFeatsPerBatch = nPointsPerLayer * nPicsPerBatch
    print("run for", dcfgs.nBatches, "batches", "nFeatsPerBatch", nFeatsPerBatch)  # dcfgs.nBatches=500, nFeatsPerBatch=100
    nFeats = nFeatsPerBatch * nBatches  # 100 * 500 = 50000
    for name in names:
        shapes[name] = (self.blobs_height(name), self.blobs_width(name))  # input sizes of the conv layers (224x224, 112x112, 56x56, 28x28, 14x14)
        if inner or len(self.blobs_shape(name)) == 2 or (shapes[name][0] == 1 and shapes[name][1] == 1):
            if 0: print(name)
            chs = self.blobs_channels(name)
            if len(self.blobs_shape(name)) == 4:
                chs *= shapes[name][0] * shapes[name][1]  # flattened size of the 4-D blob
            # feats_dict records each layer's input; for fc layers each entry
            # has nPicsPerBatch * nBatches_fc (5000) rows
            feats_dict[name] = np.ndarray(shape=(nPicsPerBatch * dcfgs.nBatches_fc, chs))
        else:  # this branch is taken for conv layers
            # feats_dict maps a conv's name to its sampled feature map,
            # with shape (50000, out_channels)
            feats_dict[name] = np.ndarray(shape=(nFeats, self.blobs_channels(name)))
        print("Extracting", name, feats_dict[name].shape)
    idx = 0
    fc_idx = 0
    if save:
        if not frozen_points:
            set_points_dict("data", self.data().shape)
            set_points_dict("label", self.label().shape)
    runforn = dcfgs.nBatches_fc if dcfgs.dic.fitfc else dcfgs.nBatches
    for batch in range(runforn):
        if save:
            if not frozen_points:
                self.forward()
                set_points_dict((batch, 0), self.data().copy())
                set_points_dict((batch, 1), self.label().copy())
            else:
                self.net.set_input_arrays(points_dict[(batch, 0)], points_dict[(batch, 1)])
                self.forward()
        else:
            self.forward()
        for name in names:
            # pad = pads[name]
            shape = shapes[name]
            feat = self.blobs_data(name)
            if 0: print(name, self.blobs_shape(name))
            if inner or len(self.blobs_shape(name)) == 2 or (shape[0] == 1 and shape[1] == 1):
                feats_dict[name][fc_idx:(fc_idx + nPicsPerBatch)] = feat.reshape((self.num, -1))
                continue
            if batch >= dcfgs.nBatches and name in self.convs:
                continue
            # TODO!!! different patch for different image per batch
            if save:
                if not frozen_points or (batch, name, "randx") not in points_dict:
                    #embed()
                    randx = np.random.randint(0, shape[0]-0, nPointsPerLayer)
                    randy = np.random.randint(0, shape[1]-0, nPointsPerLayer)
                    if dcfgs.dic.option == cfgs.pruning_options.resnet:
                        branchrandxy = None
                        branch1name = '_branch1'
                        branch2cname = '_branch2c'
                        if name in self.sums:
                            #embed()
                            nextblock = self.sums[self.sums.index(name)+1]
                            nextb1 = nextblock + branch1name
                            if not nextb1 in names:
                                # the previous sum and branch2c will be identical
                                branchrandxy = nextblock + branch2cname
                        elif name in self.bns:
                            if dcfgs.model == cfgs.Models.xception:
                                branchrandxy = 'interstellar' + name.split('bn')[1].split('_')[0] + branch2cname
                            elif dcfgs.model == cfgs.Models.resnet:
                                branchrandxy = 'res' + name.split('bn')[1].split('_')[0] + branch2cname
                                #print("correpondance", branchrandxy)
                        if branchrandxy is not None:
                            if 0: print('pointsdict of', branchrandxy, 'identical with', name)
                            randx = points_dict[(batch, branchrandxy, "randx")]
                            randy = points_dict[(batch, branchrandxy, "randy")]
                    if name.endswith('_conv1') and dcfgs.dic.option == 1:
                        if DEBUG: redprint("this line executed because dcfgs.dic.option is 1 [net.extract_features()]")
                        fsums = ['first_conv'] + self.sums
                        blockname = name.partition('_conv1')[0]
                        nextb1 = fsums[fsums.index(blockname+'_sum')-1]
                        branch1name = blockname + '_proj'
                        if branch1name in self.convs:
                            nextb1 = branch1name
                        randx = points_dict[(batch, nextb1, "randx")]
                        randy = points_dict[(batch, nextb1, "randy")]
                    set_points_dict((batch, name, "randx"), randx.copy())
                    set_points_dict((batch, name, "randy"), randy.copy())
                else:
                    randx = points_dict[(batch, name, "randx")]
                    randy = points_dict[(batch, name, "randy")]
            else:
                randx = np.random.randint(0, shape[0]-0, nPointsPerLayer)
                randy = np.random.randint(0, shape[1]-0, nPointsPerLayer)
            for point, x, y in zip(range(nPointsPerLayer), randx, randy):
                i_from = idx + point * nPicsPerBatch
                try:
                    feats_dict[name][i_from:(i_from + nPicsPerBatch)] = feat[:, :, x, y].reshape((self.num, -1))
                except:
                    print('total', runforn, 'batch', batch, 'from', i_from, 'to', i_from + nPicsPerBatch)
                    raise Exception("out of bound")
            if DEBUG:
                embed()
        idx += nFeatsPerBatch
        fc_idx += nPicsPerBatch
    dcfgs.data = cfgs.Data.lmdb
    self.clr_acc()
    if save:
        if frozen_points:
            if points_dict is not None:
                return feats_dict, points_dict
            return feats_dict
        else:
            return feats_dict, points_dict
    else:
        return feats_dict

The VH_decompose function

The spatial decomposition inside solve is done with SVD; the core function is VH_decompose:

def VH_decompose(weights, rank=None, DEBUG=0, X=None, Y=None):
    """
    Param:
        weights: n c h w
        rank: number of channels to keep
    Returns:
        V: rank c h 1   (the h x w kernel is split into h x 1 and 1 x w kernels)
        H: n rank 1 w
    """
    dim = weights.shape
    VH = np.transpose(weights, [1, 2, 0, 3])  # c h n w
    VH = VH.reshape([dim[1]*dim[2], dim[0]*dim[3]])  # ch x nw
    # SVD factors VH into V S H:
    # V is the ch x ch matrix of left singular vectors; S holds the singular
    # values (stored as a ch-vector of diagonal entries); H is the nw x nw
    # matrix of right singular vectors
    V, sigmaVH, H = svd(VH)
    if rank is None:  # the rank passed in is smaller than ch, i.e. the target dimensionality
        rank = dim[1] * dim[2]  # ch
    # as in PCA, reducing to rank dimensions means keeping the first rank singular vectors
    # ch x rank
    V = V[:, :rank]  # keep the first rank columns of V
    # rank x nw
    H = H[:rank, :]  # keep the first rank rows of H
    # rank,
    sigmaVH = sigmaVH[:rank]
    # rank x nw
    H = np.diag(sigmaVH).dot(H)  # expand sigmaVH into a rank x rank diagonal matrix, then multiply into H, giving the rank x nw matrix H
    # recover error
    # ch x nw -> c h n w: the (ch x rank) V times the (rank x nw) H gives the
    # ch x nw low-rank reconstruction VHr (really V*S*H), reshaped back to
    # the original layout
    VHr = (V.dot(H)).reshape([dim[1], dim[2], dim[0], dim[3]])
    if 0:  # DEBUG: reconstruction error between VH and its low-rank version VHr
        print('ABS ErrVH', np.mean(np.abs(VHr.flatten()-VH.flatten())))
        print('REL ABS ErrVH', np.mean(np.abs(VHr.flatten()-VH.flatten()) / np.abs(VH.flatten())))
    # rank nw -> rank n w 1
    H = H.reshape([rank, dim[0], dim[3], 1])
    # rank n w 1 -> n rank 1 w
    H = np.transpose(H, [1, 0, 3, 2])  # n rank 1 w
    # ch rank -> c 1 h rank -> rank c h 1
    origV = V.copy()
    V = V.reshape((dim[1], 1, dim[2], rank))
    V = np.transpose(V, [3, 0, 2, 1])  # rank c h 1
    if X is not None:
        Xv = np.tensordot(X, V, [[1, 2], [1, 2]])
        Xv = np.transpose(Xv, [0, 2, 3, 1])
        N = Xv.shape[0]
        o = H.shape[0]
        H, b = nonlinear_fc(Xv.reshape([N, -1]), Y)
        H = H.reshape([o, rank, 1, 3])
        reH = np.transpose(H, [1, 0, 2, 3]).reshape([rank, -1])
        VHr = (origV.dot(reH)).reshape([dim[1], dim[2], dim[0], dim[3]])
    VHr = np.transpose(VHr, [2, 0, 1, 3])
    if 1:
        epscheck(V, 2)
        epscheck(H, 2)
        epscheck(VHr, 2)
    if X is not None:
        return V, H, VHr, b
    return V, H, VHr
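The reshaping is the subtle part. This standalone numpy sketch (illustrative sizes, numpy's svd instead of the repo's wrapper) shows how an h x w kernel becomes a (c*h) x (n*w) matrix, is truncated by SVD, and is split into an h x 1 conv V and a 1 x w conv H:

import numpy as np

n, c, k = 8, 4, 3
rank = 6
W = np.random.randn(n, c, k, k)

M = np.transpose(W, [1, 2, 0, 3]).reshape(c * k, n * k)  # c h n w -> ch x nw
U, s, Vt = np.linalg.svd(M, full_matrices=False)
V = U[:, :rank]                        # ch x rank
H = np.diag(s[:rank]).dot(Vt[:rank])   # rank x nw

# low-rank reconstruction error of the flattened kernel
Mr = V.dot(H)
print('rel err:', np.abs(Mr - M).mean() / np.abs(M).mean())

# reshape into two conv kernels: V -> (rank, c, k, 1), H -> (n, rank, 1, k)
V4 = np.transpose(V.reshape(c, 1, k, rank), [3, 0, 2, 1])
H4 = np.transpose(H.reshape(rank, n, k, 1), [1, 0, 3, 2])
print(V4.shape, H4.shape)  # (6, 4, 3, 1) (8, 6, 1, 3)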

The dictionary function

The main function that performs channel pruning is dictionary():

# dictionary(newX, W2, Y, rank=d_prime, B2=self.param_b_data(Y_name))
# alpha is the L1 coefficient; pushing it past 1 makes performance drop
# sharply because the model becomes too simple
def dictionary(X, W2, Y, alpha=1e-4, rank=None, DEBUG=0, B2=None, rank_tol=.1, verbose=0):
    verbose = 0
    if verbose:
        timer = Timer()
        timer.tic()
    if 0 and rank_tol != dcfgs.dic.rank_tol:
        print("rank_tol", dcfgs.dic.rank_tol)
    rank_tol = dcfgs.dic.rank_tol
    # X: N c h w,  W2: n c h w
    norank = dcfgs.autodet  # False
    if norank:
        rank = None
    # TODO remove this
    N = X.shape[0]
    c = X.shape[1]  # number of input feature-map channels
    h = X.shape[2]
    w = h
    n = W2.shape[0]
    # TODO I forgot this
    # TODO support grp lasso: group lasso is not supported yet
    if h == 1 and False:
        for i in range(2):
            assert Y.shape[i] == X.shape[i]
            pass
        grp_lasso = True
        mtl = 1
    else:
        grp_lasso = False
    if norank:
        alpha = cfgs.alpha / c**dcfgs.dic.layeralpha
    if grp_lasso:  # False
        reX = X.reshape((N, -1))
        ally = Y.reshape((N, -1))
        samples = np.random.choice(N, N//10, replace=False)
        Z = reX[samples].copy()
        reY = ally[samples].copy()
    else:
        samples = np.random.randint(0, N, min(400, N//20))
        # c N hw
        reX = np.rollaxis(X.reshape((N, c, -1))[samples], 1, 0)
        # c hw n
        reW2 = np.transpose(W2.reshape((n, c, -1)), [1, 2, 0])
        if dcfgs.dic.alter:
            W2_std = np.linalg.norm(reW2.reshape(c, -1), axis=1)
        # c Nn
        # lasso solves for the optimal coefficients given inputs and targets;
        # here we solve for the best beta between reY and X*W, so Z is
        # precomputed as X*W
        Z = np.matmul(reX, reW2).reshape((c, -1)).T
        # Nn
        reY = Y[samples].reshape(-1)
    if grp_lasso:  # False
        if mtl:
            print("solver: group lasso")
            _solver = MultiTaskLasso(alpha=alpha, selection='random', tol=1e-1)
        else:
            _solver = Lasso(alpha=alpha, selection='random')
    elif dcfgs.solver == cfgs.solvers.lightning:  # sklearn == lightning
        _solver = CDRegressor(C=1/reY.shape[0]/2, alpha=alpha, penalty='l1', n_jobs=10)
    else:  # this branch is taken; alpha is the coefficient of the L1 penalty
        _solver = Lasso(alpha=alpha, warm_start=True, selection='random')

    def solve(alpha):
        if dcfgs.dic.debug:  # 0
            return np.array(c*[True]), c
        _solver.alpha = alpha
        _solver.fit(Z, reY)  # Z is already X*W, so the fitted coefficients are beta
        if grp_lasso and mtl:
            idxs = _solver.coef_[0] != 0.
        else:
            idxs = _solver.coef_ != 0.  # idxs marks the betas that are nonzero after the Lasso fit
            if dcfgs.solver == cfgs.solvers.lightning:
                idxs = idxs[0]
        tmp = sum(idxs)  # tmp is the number of nonzero betas
        return idxs, tmp

    def updateW2(idxs):
        nonlocal Z
        tmp_r = sum(idxs)
        reW2, _ = fc_kernel((X[:, idxs, :, :]).reshape(N, tmp_r*h*w), Y)
        reW2 = reW2.T.reshape(tmp_r, h*w, n)
        nowstd = np.linalg.norm(reW2.reshape(tmp_r, -1), axis=1)
        reW2 = (W2_std[idxs] / nowstd)[:, np.newaxis, np.newaxis] * reW2
        newshape = list(reW2.shape)
        newshape[0] = c
        newreW2 = np.zeros(newshape, dtype=reW2.dtype)
        newreW2[idxs, ...] = reW2
        Z = np.matmul(reX, newreW2).reshape((c, -1)).T
        if 0:
            print(_solver.coef_)
        return reW2

    if rank == c:  # rank is the number of channels to keep, c the actual number of channels
        idxs = np.array([True] * rank)
    elif not norank:  # norank is False, so this branch is taken
        left = 0
        right = cfgs.alpha  # upper bound of the alpha search
        lbound = rank  # - rank_tol * c
        if rank_tol >= 1:
            rbound = rank + rank_tol
        else:
            rbound = rank + rank_tol * rank
            # rbound = rank + rank_tol * c
            if rank_tol == .2:
                print("TODO: remove this")
                lbound = rank + 0.1 * rank
                rbound = rank + 0.2 * rank
        while True:  # keep fitting with Lasso until the number of nonzeros tmp satisfies the rank constraint
            _, tmp = solve(right)  # tmp is the number of nonzero betas after the Lasso fit
            if False and dcfgs.dic.alter:
                if tmp > rank:
                    break
                else:
                    right /= 2
                    if verbose: print("relax right to", right)
            else:
                if tmp < rank:  # stop once the number of nonzeros drops below rank
                    break
                else:
                    # otherwise keep solving; right is the L1 coefficient, and
                    # increasing it drives more coefficients to zero
                    right *= 2
                    if verbose: print("relax right to", right)
        while True:
            alpha = (left+right) / 2
            idxs, tmp = solve(alpha)
            if verbose: print(tmp, alpha, left, right)
            if tmp > rbound:
                left = alpha
            elif tmp < lbound:
                right = alpha
            else:
                break
        if dcfgs.dic.alter:
            if rbound == lbound:
                rbound += 1
            orig_step = left/100 + 0.1  # right / 10
            step = orig_step

            def waitstable(a):
                tmp = -1
                cnt = 0
                for i in range(10):
                    tmp_rank = tmp
                    idxs, tmp = solve(a)
                    if tmp == 0:
                        break
                    updateW2(idxs)
                    if tmp_rank == tmp:
                        cnt += 1
                    else:
                        cnt = 0
                    if cnt == 2:
                        break
                    if 1:
                        if verbose: print(tmp, "Z", Z.mean(), "c", _solver.coef_.mean())
                return idxs, tmp

            previous_Z = Z.copy()
            state = 0
            statecnt = 0
            inc = 10
            while True:
                Z = previous_Z.copy()
                idxs, tmp = waitstable(alpha)
                if tmp > rbound:
                    if state == 1:
                        state = 0
                        step /= 2
                        statecnt = 0
                    else:
                        statecnt += 1
                    if statecnt >= 2:
                        step *= inc
                    alpha += step
                elif tmp < lbound:
                    if state == 0:
                        state = 1
                        step /= 2
                        statecnt = 0
                    else:
                        statecnt += 1
                    if statecnt >= 2:
                        step *= inc
                    alpha -= step
                else:
                    break
                if verbose: print(tmp, alpha, 'step', step)
        rank = tmp
    else:
        print("start lasso kernel")
        idxs, rank = solve(alpha)
        print("end lasso kernel")
    if verbose:
        timer.toc(show='lasso')
        timer.tic()
    if grp_lasso:  # False
        inW, inB = fc_kernel(reX[:, idxs], ally, copy_X=True)

        def preconv(a, b, res, org_res):
            '''
            a: c c'
            b: n c h w
            res: c
            '''
            w = np.tensordot(a, b, [[0], [1]])
            r = np.tensordot(res, b, [[0], [1]]).sum((1, 2)) + org_res
            return np.rollaxis(w, 1, 0), r

        newW2, newB2 = preconv(inW, W2, inB, B2)
    elif dcfgs.ls == cfgs.solvers.lowparams:
        reg = LinearRegression(copy_X=True, n_jobs=-1)
        assert dcfgs.fc_ridge == 0
        assert dcfgs.dic.alter == 0, "Z changed"
        reg.fit(Z[:, idxs], reY)  # fit the regression on the selected columns
        newW2 = reg.coef_[np.newaxis, :, np.newaxis, np.newaxis] * W2[:, idxs, :, :]
        newB2 = reg.intercept_
    elif dcfgs.nonlinear_fc:  # 0
        newW2, newB2 = nonlinear_fc(X[:, idxs, ...].reshape((N, -1)), Y)
        newW2 = newW2.reshape((n, rank, h, w))
    elif dcfgs.nofc:  # 0
        newW2 = W2[:, idxs, :, :]
        newB2 = np.zeros(n)
    else:  # this branch is taken
        # pass in the columns of X where beta is nonzero, along with the
        # matching columns of W; fc_kernel corrects the pruned model by
        # least squares, giving the new w and b
        newW2, newB2 = fc_kernel(X[:, idxs, ...].reshape((N, -1)), Y, W=W2[:, idxs, ...].reshape(n, -1), B=B2)
        newW2 = newW2.reshape((n, rank, h, w))  # reshape the new w back into a 4-D kernel
    if verbose:
        timer.toc(show='ls')
    if not norank:
        cfgs.alpha = alpha
    if verbose: print(rank)
    if DEBUG:
        # print(np.where(idxs))
        newX = X[:, idxs, ...]  # the columns of X where beta is nonzero
        return newX, newW2, newB2
    else:
        return idxs, newW2, newB2
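Stripped of the alpha search and the solver options, the core of dictionary() is: build a matrix Z whose i-th column is channel i's contribution to the next layer's output, run Lasso on (Z, Y) to decide which channels survive, then re-fit the surviving weights by least squares (the job of fc_kernel, covered next). A simplified, self-contained sketch with random data and a fixed alpha (not the repo's exact code):

import numpy as np
from sklearn.linear_model import Lasso, LinearRegression

N, c, n, k = 200, 16, 8, 3
X = np.random.randn(N, c, k, k)        # sampled input patches
W2 = np.random.randn(n, c, k, k)       # weights of the layer being adjusted
Y = np.tensordot(X, W2, axes=([1, 2, 3], [1, 2, 3]))  # (N, n) responses

# per-channel responses: column i of Z stacks X[:, i] * W2[:, i] over all outputs
reX = np.rollaxis(X.reshape(N, c, -1), 1, 0)           # (c, N, k*k)
reW2 = np.transpose(W2.reshape(n, c, -1), [1, 2, 0])   # (c, k*k, n)
Z = np.matmul(reX, reW2).reshape(c, -1).T              # (N*n, c)
reY = Y.reshape(-1)

# step 1: Lasso picks which channels survive (beta_i != 0)
solver = Lasso(alpha=1e-2, selection='random')
solver.fit(Z, reY)
idxs = solver.coef_ != 0
print('kept', idxs.sum(), 'of', c, 'channels')

# step 2: least squares re-fits the surviving weights to reconstruct Y
reg = LinearRegression(n_jobs=-1)
reg.fit(X[:, idxs].reshape(N, -1), Y)
W2_new = reg.coef_.reshape(n, idxs.sum(), k, k)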

The fc_kernel function

The pruning in the paper proceeds in two steps. First, lasso regression solves for beta, which decides the channels to prune. Second, given the pruned X and W, it solves for the W that minimizes the reconstruction error.
fc_kernel is the function that solves for this error-minimizing W.

def fc_kernel(X, Y, copy_X=True, W=None, B=None, ret_reg=False, fit_intercept=True):
    """
    return: n c
    """
    assert copy_X == True
    assert len(X.shape) == 2
    if dcfgs.ls == cfgs.solvers.gd:
        w = Worker()
        def wo():
            from .GDsolver import fc_GD
            a, b = fc_GD(X, Y, W, B, n_iters=1)
            return {'a': a, 'b': b}
        outputs = w.do(wo)
        return outputs['a'], outputs['b']
    elif dcfgs.ls == cfgs.solvers.tls:
        return tls(X, Y, debug=True)
    elif dcfgs.ls == cfgs.solvers.keras:
        _reg = keras_kernel()
        _reg.fit(X, Y, W, B)
        return _reg.coef_, _reg.intercept_
    elif dcfgs.ls == cfgs.solvers.lightning:
        #_reg = SGDRegressor(eta0=1e-8, intercept_decay=0, alpha=0, verbose=2)
        _reg = CDRegressor(n_jobs=-1, alpha=0, verbose=2)
        if 0:
            _reg.intercept_ = B
            _reg.coef_ = W
    elif dcfgs.fc_ridge > 0:  # 0
        _reg = Ridge(alpha=dcfgs.fc_ridge)
    else:  # this branch is taken: plain linear regression
        _reg = LinearRegression(n_jobs=-1, copy_X=copy_X, fit_intercept=fit_intercept)
    _reg.fit(X, Y)  # given X and Y, fit the best W
    if ret_reg:
        return _reg
    return _reg.coef_, _reg.intercept_  # these correspond to w and b
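In the default branch, fc_kernel is plain linear least squares, so on noise-free synthetic data it recovers W and b exactly. A small usage sketch with hypothetical shapes:

import numpy as np
from sklearn.linear_model import LinearRegression

N, d, n = 1000, 36, 8
X = np.random.randn(N, d)                  # pruned, flattened inputs
W_true = np.random.randn(n, d)
b_true = np.random.randn(n)
Y = X.dot(W_true.T) + b_true               # target responses

# the same fit that fc_kernel's default branch performs
reg = LinearRegression(n_jobs=-1, copy_X=True, fit_intercept=True)
reg.fit(X, Y)
print(np.allclose(reg.coef_, W_true), np.allclose(reg.intercept_, b_true))  # True True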