MXNet 系列：用工具查看 .params 文件的内容

来源:互联网 发布:久量led台灯 知乎 编辑:程序博客网 时间:2024/06/15 04:33

Caffe 自带了查看模型内容的工具。

针对 MXNet，这里也自己写了一个脚本：

import mxnet as mx
import pdb  # only needed for the optional breakpoint in main(); kept for debugging


def load_checkpoint(prefix, epoch):
    """
    Load model checkpoint from file.

    Reads ``'<prefix>-<epoch:04d>.params'`` with ``mx.nd.load`` and splits the
    saved entries by their ``'arg:'`` / ``'aux:'`` key prefix.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :return: (arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    :raises ValueError: if a saved key has no ``'<type>:'`` prefix.
    """
    save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
    arg_params = {}
    aux_params = {}
    for k, v in save_dict.items():
        # Split only on the first ':' so names that themselves contain ':'
        # are kept intact (same as the original split(':', 1)).
        tp, sep, name = k.partition(':')
        if not sep:
            raise ValueError(
                "unexpected key %r in %s-%04d.params: "
                "expected 'arg:<name>' or 'aux:<name>'" % (k, prefix, epoch))
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
        # Entries with any other type prefix are ignored.
    return arg_params, aux_params


def convert_context(params, ctx):
    """
    Copy every NDArray in ``params`` to the device context ``ctx``.

    :param params: dict of str to NDArray
    :param ctx: the context to convert to (e.g. ``mx.cpu()`` or ``mx.gpu(0)``)
    :return: a new dict of str to NDArray with context ctx; the input dict
        itself is not modified.
    """
    return {k: v.as_in_context(ctx) for k, v in params.items()}


def load_param(prefix, epoch, convert=False, ctx=None):
    """
    Wrapper around ``load_checkpoint`` that can optionally move the loaded
    parameters to a given device context.

    :param prefix: Prefix of model name.
    :param epoch: Epoch number of model we would like to load.
    :param convert: if True, copy both arg and aux params to ``ctx``.
    :param ctx: target context used when ``convert`` is True; when None,
        ``mx.cpu()`` is used.
    :return: (arg_params, aux_params)
    """
    arg_params, aux_params = load_checkpoint(prefix, epoch)
    if convert:
        if ctx is None:
            ctx = mx.cpu()
        arg_params = convert_context(arg_params, ctx)
        aux_params = convert_context(aux_params, ctx)
    return arg_params, aux_params


def main(argv=None):
    """
    Command-line entry point: load a checkpoint and print its parameters.

    Usage: ``python <script> [prefix] [epoch] [--param NAME]``. With no
    arguments it loads ``my_-0001.params`` and prints ``fc2_weight``, which
    was the previously hard-coded behaviour.

    :param argv: argument list to parse; defaults to ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Print the contents of an MXNet .params checkpoint.')
    parser.add_argument('prefix', nargs='?', default='my_',
                        help="model prefix (default: 'my_')")
    parser.add_argument('epoch', nargs='?', type=int, default=1,
                        help='epoch number (default: 1)')
    parser.add_argument('--param', default='fc2_weight',
                        help="arg param to print as a numpy array "
                             "(default: 'fc2_weight')")
    args = parser.parse_args(argv)

    result = load_param(args.prefix, args.epoch)
    # pdb.set_trace()  # uncomment to inspect `result` interactively
    print('result is')
    print(result)
    print('one of results is:')
    arg_params = result[0]
    if args.param not in arg_params:
        # Report the valid names instead of failing with a bare KeyError.
        parser.error('parameter %r not found; available arg params: %s'
                     % (args.param, ', '.join(sorted(arg_params))))
    print(arg_params[args.param].asnumpy())


if __name__ == '__main__':
    main()


0 0