使用pycaffe 对caffemodel 参数进行读取

来源:互联网 发布:人工智能介绍 编辑:程序博客网 时间:2024/06/05 11:47
<code class="language-python hljs  has-numbering"><span class="hljs-comment">#!/usr/bin/env python</span><span class="hljs-comment"># 引入“咖啡”</span><span class="hljs-keyword">import</span> caffe<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np<span class="hljs-comment"># 使输出的参数完全显示</span><span class="hljs-comment"># 若没有这一句,因为参数太多,中间会以省略号“……”的形式代替</span>np.set_printoptions(threshold=<span class="hljs-string">'nan'</span>)<span class="hljs-comment"># deploy文件</span>MODEL_FILE = <span class="hljs-string">'caffe_deploy.prototxt'</span><span class="hljs-comment"># 预先训练好的caffe模型</span>PRETRAIN_FILE = <span class="hljs-string">'caffe_iter_10000.caffemodel'</span><span class="hljs-comment"># 保存参数的文件</span>params_txt = <span class="hljs-string">'params.txt'</span>pf = open(params_txt, <span class="hljs-string">'w'</span>)<span class="hljs-comment"># 让caffe以测试模式读取网络参数</span>net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)<span class="hljs-comment"># 遍历每一层</span><span class="hljs-keyword">for</span> param_name <span class="hljs-keyword">in</span> net.params.keys():    <span class="hljs-comment"># 权重参数</span>    weight = net.params[param_name][<span class="hljs-number">0</span>].data    <span class="hljs-comment"># 偏置参数</span>    bias = net.params[param_name][<span class="hljs-number">1</span>].data    <span class="hljs-comment"># 该层在prototxt文件中对应“top”的名称</span>    pf.write(param_name)    pf.write(<span class="hljs-string">'\n'</span>)    <span class="hljs-comment"># 写权重参数</span>    pf.write(<span class="hljs-string">'\n'</span> + param_name + <span class="hljs-string">'_weight:\n\n'</span>)    <span class="hljs-comment"># 权重参数是多维数组,为了方便输出,转为单列数组</span>    weight.shape = (-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>)    <span class="hljs-keyword">for</span> w <span class="hljs-keyword">in</span> weight:        pf.write(<span class="hljs-string">'%ff, '</span> % w)    <span class="hljs-comment"># 写偏置参数</span>    pf.write(<span class="hljs-string">'\n\n'</span> + param_name + <span class="hljs-string">'_bias:\n\n'</span>)    <span class="hljs-comment"># 偏置参数是多维数组,为了方便输出,转为单列数组</span>    bias.shape = (-<span class="hljs-number">1</span>, <span class="hljs-number">1</span>)    <span class="hljs-keyword">for</span> b <span class="hljs-keyword">in</span> bias:        pf.write(<span class="hljs-string">'%ff, '</span> % b)    pf.write(<span class="hljs-string">'\n\n'</span>)pf.close</code>
一定要注意把该包含的库都加入到python 的路径当中,否则会出错。
0 0
原创粉丝点击