MXnet查看参数的权值
来源:互联网 发布:达内培训 it外企it 编辑:程序博客网 时间:2024/06/17 13:42
我们用MXnet训练好模型之后,有时想看看其中参数的权值,可以用
model.get_params()函数,具体的操作见下面的例子。
import mxnet as mximport numpy as npimport logginglogging.getLogger().setLevel(logging.DEBUG)# Training datatrain_data = np.random.uniform(0, 1, [100, 2])train_label = np.array([train_data[i][0] + 2 * train_data[i][1] for i in range(100)])batch_size = 3# Evaluation Dataeval_data = np.array([[7,2],[6,10],[12,2]])eval_label = np.array([11,26,16])train_iter = mx.io.NDArrayIter(train_data,train_label, batch_size, shuffle=True,label_name='lin_reg_label')eval_iter = mx.io.NDArrayIter(eval_data, eval_label, batch_size, shuffle=False)X = mx.sym.Variable('data')Y = mx.symbol.Variable('lin_reg_label')fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden = 1)lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")model = mx.mod.Module( symbol = lro , data_names=['data'], label_names = ['lin_reg_label'] # network structure)mx.viz.plot_network(symbol=lro)model.fit(train_iter, eval_iter, optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, num_epoch=50, eval_metric='mse', batch_end_callback = mx.callback.Speedometer(batch_size, 2))model.predict(eval_iter).asnumpy()metric = mx.metric.MSE()model.score(eval_iter, metric)keys = model.get_params()[0].keys() # 列出所有权重名称print(keys)conv_w = model.get_params()[0]['fc1_weight'] # 获取想要查看的权重信息bias = model.get_params()[0]['fc1_bias']print(conv_w.asnumpy()) # 查看具体数值print(bias.asnumpy())
阅读全文
0 0
- MXnet查看参数的权值
- mxnet系列 tools 查看params的内容
- 查看caffemodel的参数值
- mxnet显示层参数代码
- 查看Hibernate参数值的变通方法
- MXNet的Model API
- MXNet的模型园地
- mxnet的更新问题
- mshadow的原理--MXNet
- mxnet 框架的搭建
- mxnet中,SGD(随机梯度下降)的参数momentum的用处
- mxnet下如何查看中间结果
- mxnet
- MXNet
- MXNet
- MXNet
- 查看程序的启动参数,入口参数
- 查看PC的各种参数
- Jzoj1950 拉拉队排练
- 搭建Lamp之安装PHP5.6
- C++11 error: unable to find string literal operator 'operator"
- Docker-Compose简介安装使用
- centos7安装:license information
- MXnet查看参数的权值
- sql语言:DQL、DML、DDL、DCL
- 欢迎使用CSDN-markdown编辑器
- 修改jar中包结构,修改jar包包名
- 虽然很短暂,但也曾经有过
- sqlite数据库在Python中的使用简介
- opencv: cv2.flip 图像翻转 进行 数据增强
- 基于kotlin实现的简单个人账户管理APP
- tomcat设置使得url省去项目名称