scikitlearn/theano多分类问题详解
来源:互联网 发布:js event button 编辑:程序博客网 时间:2024/06/06 09:39
入门先看下面两个网址
二分类:
http://python.jobbole.com/82208/
多分类:
http://blog.csdn.net/han_xiaoyang/article/details/50521072下面说说改进及注意点:
上述博客的数据例子都是根据自己的意思随机生成的。
这边用iris数据进行扩展
由于iris数据是像下面这样的:
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
。。。。。。。
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
。。。。。。。
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
一共150行,每行4个特征外加1个类别,共3种类别
所以神经网络作如下设置
class Config:
最后一列代表花的种类,原数据是用名称代替的,这也是我们在多分类中经常碰到的。所以这里要转换为数字。
所以我们在程序里需要进行改动,如下:
t = np.zeros(len(labelMat))
model = build_model(X, t, 5, print_loss=True)
注意这里的0,1,2。我最开始想弄随便弄个什么数,比如1,2,3的,反正都是求概率。后来发现行不通。
最后又仔细研读了下代码
注意程序里有个probs,
这是个概率矩阵,比如这里就是150*3的矩阵
每一行的3个数分别代表了属于三个类的概率。
虽说这3个概率跟具体的类别用什么数字或者符号代替没啥关系。但是注意到,要想获得者三个概率,得用索引来从矩阵probs中获得啊。。。。所以必须从0开始依次增大。。。。
于是想起来以前Theano中一些深度学习的例子。原例子是对mnist进行数字识别的。是多分类。不过很巧的是,它识别的数字种类正好是0,1,2,3,4...是从0开始的,也就没仔细想。
于是又用iris数据集在Theano的DBN上跑了一遍,还是得把类别转换为0,1,2才能跑,不能是其他数字。道理和上面类似。
说到这里,再扯远一点。原来DBN的例子中值是训练加批量测试,没有给出具体的预测/分类方法。
于是我自己改了一下:
在test_DBN()函数中将验证好的模型存为best_DBN_model.pkl。然后在预测函数中,如下设置
DBN_model = pickle.load(open('best_DBN_model.pkl'))
注意这里DBN_model.x这样设置是因为def__init__函数中进行了如下设置
self.x = T.matrix('x')
self.y = T.ivector('y')
x即特征矩阵集,y即标签矩阵集
而这里DBN网络最上面是由一层logistic回归输出的(准确地说是softmax)在logistic_sgd.py文件中有如下定义:
self.y_pred = T.argmax(self.p_y_given_x, axis=1)
这就跟我们上文中神经网络代码中的return np.argmax(probs, axis=1)一样了。
最后再说一下DBN函数中调参数就调下面这段就行了,不用调初始化函数中的了。
dbn = DBN(numpy_rng=numpy_rng, n_ins=4,
其实跟神经网络一样设置。输入输出。以iris数据为例,特征是4个,n_ins=4,种类是3个,n_outs=3。
[10,10,10,10,10,10]代表我们用了6层RBM,每层RBM用的10个节点。
0 0
- scikitlearn/theano多分类问题详解
- LR进行多分类theano代码分析
- theano-多分类逻辑回归代码解析
- Theano和Tensorflow多GPU使用问题
- theano.function、theano.scan 参数数据类型问题
- centos 安装scikitlearn
- 修改了scikitlearn源文件
- XGBoost:多分类问题
- XGBoost:多分类问题
- XGBoost:多分类问题
- 多标签分类问题
- theano与keras安装问题
- svm多分类器详解
- theano
- Theano
- theano
- Theano
- 利用scikitlearn画ROC曲线
- Spark mlib FPGrowth&nb…
- Spark的最短路径详解
- 读书笔记之三十二----《信用…
- 评分卡模型剖析之一(woe、I…
- 数据挖掘技术(四)——聚类
- scikitlearn/theano多分类问题详解
- Weka 分类 注意点
- 深度学习keras程序失败的解决办法
- Java回调函数
- 用lxrun更改Linux子系统中的默认登录帐户
- csdn如何转载别人的文章
- 条款二十四:了解virtual functions、multiple inheritance、virtual base class、runtime type identification的成本
- 搭建web项目结合spring+cxf的webservice服务
- 初学主席树