mxnet下如何查看中间结果
来源:互联网 发布:女王升级数据2017 编辑:程序博客网 时间:2024/05/28 23:12
查看权重
在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。
'''查看权重示例代码转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents'''import mxnet as mxsym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Modulemod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为Falsemod.set_params(arg_params,aux_params)import numpy as npimport cv2def get_image(filename): img = cv2.imread(filename) img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) img = cv2.resize(img,(224,224)) img = np.swapaxes(img,0,2) img = np.swapaxes(img,1,2) img = img[np.newaxis,:] return imgfrom collections import namedtupleBatch = namedtuple('Batch',['data'])img = get_image('val_1000/0.jpg') #获取图片mod.forward(Batch([mx.nd.array(img)])) #预测结果#################################################debug模式下,获取权重信息keys = mod.get_params()[0].keys() # 列出所有权重名称conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weightprint conv_w.asnumpy() #查看具体数值################################################prob = mod.get_outputs()[0].asnumpy()y = np.argsort(np.squeeze(prob))[::-1]print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
查看中间输出结果
由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。
'''方法一查看中间结果代码转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents'''import mxnet as mxnet = mx.symbol.Variable('data')fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)out = mx.symbol.SoftmaxOutput(data=net, name='softmax')# 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果group = mx.symbol.Group([fc1, out]) print group.list_outputs()
'''方法二有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,这时候我们使用get_internals()方法来查找自己需要查看的中间层转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents'''import mxnet as mxsym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型########################################################################args = sym.get_internals().list_outputs() #获得所有中间输出internals = model.symbol.get_internals()fc1 = internals['fc1_output']conv = internals['stage4_unit3_conv1_output']group = mx.symbol.Group([fc1, sym, conv]) #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出#########################################################################mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Modulemod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为Falsemod.set_params(arg_params,aux_params)import numpy as npimport cv2def get_image(filename): img = cv2.imread(filename) img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) img = cv2.resize(img,(224,224)) img = np.swapaxes(img,0,2) img = np.swapaxes(img,1,2) img = img[np.newaxis,:] return imgfrom collections import namedtupleBatch = namedtuple('Batch',['data'])img = get_image('val_1000/0.jpg') #获取图片mod.forward(Batch([mx.nd.array(img)])) #预测结果prob = mod.get_outputs()[0].asnumpy()y = np.argsort(np.squeeze(prob))[::-1]print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
0 0
- mxnet下如何查看中间结果
- pg如何保存中间结果
- PyTorch学习总结(一)——查看模型中间结果
- 如何查看JOB的结果
- 如何查看Navicat 查询结果
- mxnet 训练--如何生成rec 数据 +自己在本机测试的结果
- Windows10下安装Mxnet
- Ubuntu下安装MxNet
- C++中间结果溢出
- 查看结果。。
- Modelsim查看中间变量
- windows下安装配置Mxnet
- Win10下MxNet安装手记
- Ubuntu14.04下MXNet安装
- mxnet 在windows下安装
- ubantu下mxnet版本更新
- hive中间结果和结果的压缩
- hive 压缩 最终结果 中间结果
- C#笔记9——基于TableLayoutPanel的多分屏、全屏程序
- PX4 Windows 编译环境配置
- Android基础知识总结
- VMware 中 Linux 修改 ip 方法 亦可解决 eth0消失 的情况
- Why does my Authorize Attribute not work-
- mxnet下如何查看中间结果
- 分析rusty-blockparser工具生成的解析文件含义
- 半年总结
- 微信小程序 时间戳转换
- DP —— 玲珑学院OJ 1091
- 为ViewFlipper添加点击事件,很简单
- [java-代理]测试Proxy和Enhancer两种代理方式
- 蓝桥杯寒假任务之出现次数最多的整数
- 【原创】Centos6中yum方法安装sl(linux有趣命令之一sl跑火车)