MXNet获取特征输出
来源:互联网 发布:淘宝类目官方群 编辑:程序博客网 时间:2024/05/21 06:31
在网上下载的MXNet预训练模型常常是完整的,但是在实际应用中,我们一般只需要网络中某一层作为特征提取,这个时候就需要重建模型,使得网络最后的输出是特征。
加载预训练模型
加载模型使用model.FeedForward.load
就可以了,后面的参数分别是模型的名字、迭代次数和batch大小,需要根据实际模型进行修改:
import mxnet as mximport numpy as npmodel=mx.model.FeedForward.load('model_name',1,num_batch_size=1)
找到特征层
别人训练好的模型我们常常不知道有哪些层,这时候需要列出所有的层,以便于我们找到特征层:
internals=model.symbol.get_internals() #list all symbolinternals.list_outputs()
列出网络中所有的层,像这样:
['data', 'conv1_weight', 'conv1_bias', 'conv1_output', 'slice1_output0', 'slice1_output1', '_maximum0_output', …… …… 'slice_fc1_output0', 'slice_fc1_output1', '_maximum9_output', 'drop1_output', 'fc2_weight', 'fc2_bias', 'fc2_output', 'softmax_label', 'softmax_output']
重建符号与模型
比如我们要把drop1_output
作为输出特征:
fea_symbol=internals['drop1_output'] #choose feature layerfeature_extractor=mx.model.FeedForward(symbol=fea_symbol,numpy_batch_size=1,arg_params=model.arg_params,aux_params=model.aux_params,allow_extra_params=True)
这里的feature_extractor
就是我们的新模型。
重新保存
得到的新模型当然要保存下来:
feature_extractor.save('new_model_name',1) #save new symbol and model
模型名字要根据自己的情况做一些修改。
测试一下:
new_model=mx.model.FeedForward.load('new_model_name',1,num_batch_size=1)
完整代码
完整代码在我的GitHub上:
https://github.com/flyingzhao/RebuildModel
相关issue:
https://github.com/dmlc/mxnet/issues/3883
1 0
- MXNet获取特征输出
- mxnet修改网络输出num
- 特征获取
- mxnet
- MXNet
- MXNet
- MXNet
- 获取机器特征
- 使用mxnet的预训练模型(pretrained model)分类与特征提取
- SIFT学习--特征点获取
- 获取机器特征码程序
- ArcGIS教程:获取径流特征
- 深度学习框架哪家强?MXNet称霸CNN、RNN和情感分析,TensorFlow仅擅长推断特征提取
- 深度学习框架哪家强?MXNet称霸CNN、RNN和情感分析,TensorFlow仅擅长推断特征提取
- 深度学习框架哪家强?MXNet称霸CNN、RNN和情感分析,TensorFlow仅擅长推断特征提取
- 深度学习框架哪家强?MXNet称霸CNN、RNN和情感分析,TensorFlow仅擅长推断特征提取
- [C#] 获取实时输出
- PERL: 获取system输出
- #486 – InkCanvas 支持多种编辑模式(InkCanvas Supports Different Editing Modes)
- sys_context()函数用法解析
- 项目管理利器(Maven)——pom.xml解析
- RJ45接口定义
- A1046. Shortest Distance (20)
- MXNet获取特征输出
- PCM(44字节)的Wav文件头及其相关的编程方法
- VS2010开发的winform程序在XP系统打不开的原因(与ico图标像素有关)
- cortex-m3 各种引发fault的统计
- soot基础 -- 常用参数配置
- Java调用XML的方法:DocumentBuilderFactory
- 内存泄漏和内存溢出区别
- 如何引用本地aar文件
- Error:Unable to tunnel through proxy. Proxyreturns "HTTP/1.1 400 Bad Request"